whispercpp 1.3.6 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (828)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/README.md +38 -5
  5. data/Rakefile +18 -3
  6. data/ext/dependencies.rb +10 -4
  7. data/ext/dependencies_for_windows.rb +17 -0
  8. data/ext/extconf.rb +20 -8
  9. data/ext/options.rb +54 -14
  10. data/ext/options_for_windows.rb +51 -0
  11. data/ext/ruby_whisper.c +36 -42
  12. data/ext/ruby_whisper.h +135 -0
  13. data/ext/ruby_whisper_context.c +107 -28
  14. data/ext/ruby_whisper_log_queue.c +180 -0
  15. data/ext/ruby_whisper_log_settable.h +47 -0
  16. data/ext/ruby_whisper_parakeet.c +49 -0
  17. data/ext/ruby_whisper_parakeet_context.c +304 -0
  18. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  19. data/ext/ruby_whisper_parakeet_model.c +84 -0
  20. data/ext/ruby_whisper_parakeet_params.c +548 -0
  21. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  22. data/ext/ruby_whisper_parakeet_token.c +188 -0
  23. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  24. data/ext/ruby_whisper_params.c +256 -65
  25. data/ext/ruby_whisper_segment.c +6 -6
  26. data/ext/ruby_whisper_transcribe.cpp +42 -15
  27. data/ext/sources/CMakeLists.txt +41 -3
  28. data/ext/sources/CMakePresets.json +95 -0
  29. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  30. data/ext/sources/cmake/parakeet.pc.in +10 -0
  31. data/ext/sources/cmake/whisper.pc.in +1 -1
  32. data/ext/sources/examples/CMakeLists.txt +4 -2
  33. data/ext/sources/examples/bench/bench.cpp +1 -1
  34. data/ext/sources/examples/cli/cli.cpp +43 -9
  35. data/ext/sources/examples/common-ggml.cpp +2 -0
  36. data/ext/sources/examples/common-whisper.cpp +139 -67
  37. data/ext/sources/examples/common-whisper.h +11 -0
  38. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  39. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  40. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  41. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  42. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  43. data/ext/sources/examples/server/server.cpp +199 -163
  44. data/ext/sources/ggml/CMakeLists.txt +21 -13
  45. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  46. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  47. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  48. data/ext/sources/ggml/include/ggml-backend.h +72 -10
  49. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  50. data/ext/sources/ggml/include/ggml-rpc.h +3 -3
  51. data/ext/sources/ggml/include/ggml.h +101 -9
  52. data/ext/sources/ggml/include/gguf.h +10 -2
  53. data/ext/sources/ggml/src/CMakeLists.txt +22 -5
  54. data/ext/sources/ggml/src/ggml-alloc.c +5 -1
  55. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  56. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  57. data/ext/sources/ggml/src/ggml-backend-reg.cpp +12 -0
  58. data/ext/sources/ggml/src/ggml-backend.cpp +110 -9
  59. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +4 -0
  60. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +672 -257
  61. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +71 -0
  62. data/ext/sources/ggml/src/ggml-cann/common.h +20 -10
  63. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +211 -30
  64. data/ext/sources/ggml/src/ggml-common.h +11 -0
  65. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +58 -29
  66. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +2 -0
  67. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +16 -16
  68. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +116 -7
  69. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +65 -0
  70. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  71. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  72. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +4279 -1292
  73. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +5 -35
  74. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +0 -1
  75. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  76. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +177 -27
  77. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +1 -1
  78. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +5 -0
  79. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  80. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -0
  81. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +95 -5
  82. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  83. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +146 -134
  84. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +88 -70
  85. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +372 -73
  86. data/ext/sources/ggml/src/ggml-cpu/ops.h +3 -0
  87. data/ext/sources/ggml/src/ggml-cpu/quants.c +55 -0
  88. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  89. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3 -0
  90. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +90 -0
  91. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +3 -16
  92. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  93. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  94. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  95. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  96. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  97. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  98. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  99. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  100. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  101. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  102. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  103. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  104. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  105. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  106. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  107. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +37 -53
  108. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  109. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +17 -7
  110. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  111. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  112. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +62 -26
  113. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +44 -18
  114. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  115. data/ext/sources/ggml/src/ggml-cuda/common.cuh +242 -28
  116. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  117. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  118. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  119. data/ext/sources/ggml/src/ggml-cuda/convert.cu +53 -0
  120. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  121. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +14 -6
  122. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  123. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +278 -44
  124. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +331 -130
  125. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +126 -27
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +40 -15
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +18 -9
  129. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +152 -49
  130. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  131. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  132. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  133. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +84 -35
  134. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  135. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1069 -609
  136. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  137. data/ext/sources/ggml/src/ggml-cuda/mean.cu +4 -2
  138. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +242 -195
  139. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +3 -3
  140. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  141. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +502 -423
  142. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +19 -12
  143. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +485 -57
  144. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -1
  145. data/ext/sources/ggml/src/ggml-cuda/norm.cu +36 -10
  146. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  147. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +133 -26
  148. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  149. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +5 -1
  150. data/ext/sources/ggml/src/ggml-cuda/rope.cu +11 -4
  151. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  152. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  153. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  154. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  155. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  156. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +45 -13
  157. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  158. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  159. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  160. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  161. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +1 -0
  162. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  163. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  164. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  165. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +1 -0
  166. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  167. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  168. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  169. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  170. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  171. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  172. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  173. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  174. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  176. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  177. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  178. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  179. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  191. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -4
  192. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +26 -23
  193. data/ext/sources/ggml/src/ggml-cuda/unary.cu +31 -2
  194. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  195. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +80 -0
  196. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  197. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -4
  198. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  199. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +2 -1
  200. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1428 -743
  201. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -7
  202. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +53 -84
  203. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +25 -12
  204. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +165 -184
  205. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  206. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  207. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +170 -127
  208. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  209. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  210. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  211. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +125 -97
  212. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  213. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +148 -42
  214. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +2 -2
  215. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +252 -62
  216. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +9 -0
  217. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +87 -1
  218. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  219. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  220. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  221. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  222. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  223. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  224. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  225. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  226. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +96 -13
  227. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +182 -57
  228. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  229. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +71 -3
  230. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +27 -10
  231. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +63 -23
  232. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +9 -8
  233. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  235. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +1 -0
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +5 -8
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +529 -815
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2522 -234
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +291 -95
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +59 -37
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +121 -133
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +244 -151
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +6 -6
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +719 -45
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +3 -1
  254. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +22 -9
  255. data/ext/sources/ggml/src/ggml-impl.h +6 -1
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +138 -13
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +32 -1
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +164 -28
  259. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +80 -0
  260. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +190 -19
  261. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +2 -0
  262. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +39 -26
  263. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +823 -322
  264. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  265. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +54 -5
  266. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +12248 -5907
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +67 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +59 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1819 -112
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mm_q8_0_f32_8x4.cl → gemm_noshuffle_q8_0_f32.cl} +1 -1
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +48 -64
  322. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +15 -5
  323. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +18 -11
  324. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +35 -13
  325. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +264 -192
  326. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +33 -7
  327. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  328. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +1 -0
  329. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +1 -0
  330. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  331. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +27 -3
  332. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +67 -36
  333. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +1 -0
  334. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +101 -44
  335. data/ext/sources/ggml/src/ggml-openvino/utils.h +23 -3
  336. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  337. data/ext/sources/ggml/src/ggml-quants.c +289 -114
  338. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  339. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  340. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  341. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  342. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  343. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +50 -4
  344. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +1 -1
  345. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +3 -1
  346. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  347. data/ext/sources/ggml/src/ggml-sycl/common.hpp +41 -1
  348. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +115 -13
  349. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +9 -0
  350. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  351. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  352. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  353. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  354. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  355. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  356. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1 -90
  357. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +0 -2
  358. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  359. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  360. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +7 -5
  361. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +4 -0
  362. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +76 -168
  363. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +7 -0
  364. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +3 -1
  365. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  366. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  367. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +69 -31
  368. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +1 -0
  369. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  370. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  371. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +823 -190
  372. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  373. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  374. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  375. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  376. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  377. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +71 -0
  378. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  379. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  380. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  381. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  382. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  383. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  384. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  385. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  386. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  387. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +1 -0
  388. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +1 -0
  389. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +1 -0
  390. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +1 -0
  391. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +1 -0
  392. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +1 -0
  393. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +1 -0
  394. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +1 -0
  395. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +1 -0
  396. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +1 -0
  397. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +1 -0
  398. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +1 -0
  399. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +1 -0
  400. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +1 -0
  401. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +1 -0
  402. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +1 -0
  403. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +1 -0
  404. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +1 -0
  405. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +1 -0
  406. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +1 -0
  407. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +1 -0
  408. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +1 -0
  409. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +1 -0
  410. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +1 -0
  411. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +1 -0
  412. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +1 -0
  413. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +1 -0
  414. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +1 -0
  415. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +1 -0
  416. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +1 -0
  417. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +1 -0
  418. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +1 -0
  420. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +1 -0
  421. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +1 -0
  422. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +1 -0
  423. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  424. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  425. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  426. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +215 -53
  427. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +4 -0
  428. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +2 -0
  429. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +2 -0
  430. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +1 -0
  431. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +1 -0
  432. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +0 -2
  433. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +11 -0
  434. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +2060 -535
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +197 -48
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +60 -59
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +115 -113
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +122 -31
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +125 -64
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -17
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +43 -10
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +5 -2
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +3 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +11 -1
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +171 -147
  484. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  485. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +2202 -283
  486. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2610 -1403
  487. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  488. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  489. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +8 -7
  490. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +76 -95
  491. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +19 -1
  492. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  493. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -50
  494. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +107 -184
  495. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  496. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  497. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  498. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  499. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  500. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  501. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +183 -78
  502. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  503. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  504. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +655 -495
  505. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  506. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  507. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  508. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +8 -6
  509. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +5 -1
  510. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +80 -409
  511. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  512. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  513. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  514. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  515. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  516. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  517. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  518. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +6 -4
  519. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  520. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +2 -3
  521. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  522. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  523. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  524. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  525. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  526. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +68 -48
  527. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  528. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +18 -14
  529. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +1 -1
  530. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +244 -10
  531. data/ext/sources/ggml/src/ggml.c +110 -28
  532. data/ext/sources/ggml/src/gguf.cpp +173 -28
  533. data/ext/sources/include/parakeet.h +342 -0
  534. data/ext/sources/include/whisper.h +10 -0
  535. data/ext/sources/media/matmul.png +0 -0
  536. data/ext/sources/src/CMakeLists.txt +23 -0
  537. data/ext/sources/src/parakeet-arch.h +188 -0
  538. data/ext/sources/src/parakeet.cpp +3838 -0
  539. data/ext/sources/src/whisper.cpp +56 -12
  540. data/extsources.rb +26 -10
  541. data/lib/whisper/log_settable.rb +36 -0
  542. data/lib/whisper/model/uri.rb +13 -1
  543. data/lib/whisper/output.rb +74 -0
  544. data/sig/whisper.rbs +411 -62
  545. data/test/helper.rb +2 -0
  546. data/test/jfk_reader/jfk_reader.c +50 -7
  547. data/test/test_callback.rb +1 -0
  548. data/test/test_package.rb +6 -5
  549. data/test/test_parakeet.rb +28 -0
  550. data/test/test_parakeet_callback.rb +107 -0
  551. data/test/test_parakeet_context.rb +116 -0
  552. data/test/test_parakeet_context_params.rb +24 -0
  553. data/test/test_parakeet_model.rb +21 -0
  554. data/test/test_parakeet_params.rb +78 -0
  555. data/test/test_parakeet_segment.rb +42 -0
  556. data/test/test_parakeet_token.rb +73 -0
  557. data/test/test_params.rb +2 -0
  558. data/test/test_vad_segment.rb +1 -1
  559. data/test/test_whisper.rb +24 -6
  560. data/whispercpp.gemspec +2 -2
  561. metadata +215 -281
  562. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  563. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  564. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  565. data/ext/sources/bindings/javascript/package.json +0 -26
  566. data/ext/sources/bindings/javascript/whisper.js +0 -19
  567. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  568. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  569. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  570. data/ext/sources/examples/addon.node/index.js +0 -59
  571. data/ext/sources/examples/addon.node/package.json +0 -16
  572. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  573. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  574. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  575. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  576. data/ext/sources/examples/coi-serviceworker.js +0 -146
  577. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  578. data/ext/sources/examples/command/command.cpp +0 -802
  579. data/ext/sources/examples/command/commands.txt +0 -9
  580. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  581. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  582. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  583. data/ext/sources/examples/generate-karaoke.sh +0 -57
  584. data/ext/sources/examples/helpers.js +0 -191
  585. data/ext/sources/examples/livestream.sh +0 -112
  586. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  587. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  588. data/ext/sources/examples/lsp/whisper.vim +0 -362
  589. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  590. data/ext/sources/examples/python/whisper_processor.py +0 -54
  591. data/ext/sources/examples/server/bench.js +0 -29
  592. data/ext/sources/examples/server.py +0 -120
  593. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  594. data/ext/sources/examples/stream/stream.cpp +0 -437
  595. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  596. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  597. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  598. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  599. data/ext/sources/examples/sycl/build.sh +0 -22
  600. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  601. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  602. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -48
  603. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  604. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -488
  605. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -89
  606. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2877
  607. data/ext/sources/examples/talk-llama/llama-arch.h +0 -628
  608. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -919
  609. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  610. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -896
  611. data/ext/sources/examples/talk-llama/llama-chat.h +0 -71
  612. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3633
  613. data/ext/sources/examples/talk-llama/llama-context.h +0 -359
  614. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  615. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -47
  616. data/ext/sources/examples/talk-llama/llama-ext.h +0 -12
  617. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  618. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  619. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2735
  620. data/ext/sources/examples/talk-llama/llama-graph.h +0 -1031
  621. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -258
  622. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -353
  623. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  624. data/ext/sources/examples/talk-llama/llama-impl.h +0 -75
  625. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  626. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  627. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -330
  628. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  629. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2285
  630. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -389
  631. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  632. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +0 -275
  633. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +0 -140
  634. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  635. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  636. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1165
  637. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  638. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  639. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  640. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -752
  641. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  642. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1655
  643. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -206
  644. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -299
  645. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -40
  646. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -9056
  647. data/ext/sources/examples/talk-llama/llama-model.h +0 -597
  648. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1304
  649. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  650. data/ext/sources/examples/talk-llama/llama-sampler.cpp +0 -3885
  651. data/ext/sources/examples/talk-llama/llama-sampler.h +0 -42
  652. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3970
  653. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -187
  654. data/ext/sources/examples/talk-llama/llama.cpp +0 -1194
  655. data/ext/sources/examples/talk-llama/llama.h +0 -1573
  656. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -190
  657. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  658. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  659. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -137
  660. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  661. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -123
  662. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -143
  663. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -133
  664. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -184
  665. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -145
  666. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  667. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  668. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  669. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  670. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  671. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  672. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  673. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -122
  674. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  675. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -142
  676. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -262
  677. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +0 -445
  678. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -132
  679. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  680. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -148
  681. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  682. data/ext/sources/examples/talk-llama/models/eurobert.cpp +0 -97
  683. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +0 -145
  684. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  685. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  686. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -111
  687. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  688. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  689. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  690. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  691. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  692. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  693. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  694. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -157
  695. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  696. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  697. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -195
  698. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -210
  699. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  700. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -139
  701. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  702. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -153
  703. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  704. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  705. data/ext/sources/examples/talk-llama/models/jais2.cpp +0 -123
  706. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  707. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +0 -381
  708. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -196
  709. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  710. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  711. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  712. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -175
  713. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  714. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +0 -289
  715. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -54
  716. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -129
  717. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -200
  718. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -123
  719. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  720. data/ext/sources/examples/talk-llama/models/models.h +0 -704
  721. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -109
  722. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  723. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -162
  724. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  725. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  726. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  727. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  728. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  729. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  730. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  731. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  732. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +0 -122
  733. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  734. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  735. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  736. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  737. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -320
  738. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  739. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -169
  740. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  741. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  742. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  743. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  744. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -120
  745. data/ext/sources/examples/talk-llama/models/qwen35.cpp +0 -381
  746. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +0 -422
  747. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -131
  748. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -525
  749. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -140
  750. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -132
  751. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  752. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  753. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -164
  754. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  755. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  756. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -137
  757. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  758. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  759. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  760. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  761. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  762. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  763. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  764. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +0 -165
  765. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  766. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  767. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  768. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  769. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  770. data/ext/sources/examples/talk-llama/speak +0 -40
  771. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  772. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  773. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  774. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  775. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  776. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1103
  777. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  778. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  779. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  780. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  781. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  782. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  783. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  784. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  785. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  786. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  787. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  788. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  789. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  790. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  791. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -155
  792. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  793. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  794. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +0 -123
  795. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +0 -17
  796. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +0 -333
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  798. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -182
  799. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  800. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -718
  801. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  802. data/ext/sources/tests/CMakeLists.txt +0 -112
  803. data/ext/sources/tests/earnings21/eval.mk +0 -58
  804. data/ext/sources/tests/earnings21/eval.py +0 -68
  805. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  806. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  807. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  808. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  809. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  810. data/ext/sources/tests/en-0-ref.txt +0 -1
  811. data/ext/sources/tests/en-1-ref.txt +0 -1
  812. data/ext/sources/tests/en-2-ref.txt +0 -1
  813. data/ext/sources/tests/es-0-ref.txt +0 -1
  814. data/ext/sources/tests/librispeech/eval.mk +0 -39
  815. data/ext/sources/tests/librispeech/eval.py +0 -47
  816. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  817. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  818. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  819. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  820. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  821. data/ext/sources/tests/run-tests.sh +0 -130
  822. data/ext/sources/tests/test-c.c +0 -3
  823. data/ext/sources/tests/test-vad-full.cpp +0 -56
  824. data/ext/sources/tests/test-vad.cpp +0 -83
  825. data/ext/sources/tests/test-whisper.js +0 -58
  826. data/lib/whisper/context.rb +0 -15
  827. data/lib/whisper/segment.rb +0 -58
  828. /data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general_q8_0_f32.cl → gemv_noshuffle_q8_0_f32.cl} +0 -0
@@ -56,6 +56,65 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
56
56
  }
57
57
  }
58
58
 
59
+ template <typename reorder_vec_dot_q_sycl, int ncols_dst>
60
+ static void mul_mat_vec_q_reorder_ncols(const void * __restrict__ vx, const void * __restrict__ vy,
61
+ float * __restrict__ dst, const int ncols, const int nrows,
62
+ const int stride_col_y_bytes, const int stride_col_dst,
63
+ const sycl::nd_item<3> & nd_item) {
64
+ using block_type = ggml_sycl_reordered::block_q_t<reorder_vec_dot_q_sycl::gtype>;
65
+ using block_traits = typename block_type::traits;
66
+
67
+ const auto sg = nd_item.get_sub_group();
68
+ const int sg_range = sg.get_group_linear_range();
69
+ const int workgroup_id = nd_item.get_group_linear_id();
70
+ const int sg_id = sg.get_group_linear_id();
71
+ const int row = workgroup_id * sg_range + sg_id;
72
+
73
+ if (row >= nrows) {
74
+ return;
75
+ }
76
+
77
+ const int blocks_per_row = ncols / block_traits::qk;
78
+ constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi);
79
+ constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq;
80
+ const int nblocks = nrows * (ncols / block_traits::qk);
81
+
82
+ static_assert(blocks_per_subgroup > 0);
83
+ static_assert(block_elements_per_subgroup > 0);
84
+
85
+ float partial_sum[ncols_dst] = {0.0f};
86
+ for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
87
+ const int ibx = row * blocks_per_row + i;
88
+
89
+ const auto bx_offset = block_type::get_block_offset(ibx, nblocks);
90
+ const auto d_offset = block_type::get_d_offset(nrows, ncols, ibx);
91
+ const int iby = i * block_type::block_to_q8_1_ratio();
92
+
93
+ #pragma unroll
94
+ for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) {
95
+ const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
96
+
97
+ #pragma unroll
98
+ for (int j = 0; j < ncols_dst; ++j) {
99
+ const char * vy_j = (const char *)vy + j * stride_col_y_bytes;
100
+ const int8_t * q8_1_quant_ptr = (const int8_t *)vy_j + iby * QK8_1;
101
+ const sycl::half2* q8_1_ds_ptr = (const sycl::half2 *)(vy_j + ncols + iby * sizeof(sycl::half2));
102
+
103
+ partial_sum[j] += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs);
104
+ }
105
+ }
106
+ }
107
+
108
+ #pragma unroll
109
+ for (int j = 0; j < ncols_dst; ++j) {
110
+ float sum = sycl::reduce_over_group(nd_item.get_sub_group(), partial_sum[j], std::plus<>());
111
+
112
+ if (sg.leader()) {
113
+ dst[j * stride_col_dst + row] = sum;
114
+ }
115
+ }
116
+ }
117
+
59
118
  template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
60
119
  static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
61
120
  const int ncols, const int nrows, const sycl::nd_item<3> & item_ct1) {
@@ -100,6 +159,70 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
100
159
  }
101
160
  }
102
161
 
162
+ template <int qk, int qi, typename block_q_t, int vdr,
163
+ vec_dot_q_sycl_t vec_dot_q_sycl, int ncols_dst>
164
+ static void mul_mat_vec_q_ncols(
165
+ const void * __restrict__ vx,
166
+ const void * __restrict__ vy,
167
+ float * __restrict__ dst,
168
+ const int ncols,
169
+ const int nrows,
170
+ const int stride_col_y,
171
+ const int stride_col_dst,
172
+ const sycl::nd_item<3> & item_ct1) {
173
+
174
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1)
175
+ + item_ct1.get_local_id(1);
176
+
177
+ if (row >= nrows) {
178
+ return;
179
+ }
180
+
181
+ const int blocks_per_row = ncols / qk;
182
+ constexpr int blocks_per_warp = (vdr * WARP_SIZE + qi - 1) / qi;
183
+
184
+ // partial sums: one per output column
185
+ float tmp[ncols_dst] = {0.0f};
186
+
187
+ const block_q_t * x = (const block_q_t *) vx;
188
+ const block_q8_1 * y = (const block_q8_1 *) vy;
189
+
190
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr);
191
+ i < blocks_per_row;
192
+ i += blocks_per_warp) {
193
+
194
+ const int ibx = row * blocks_per_row + i;
195
+ const int iby = i * (qk / QK8_1);
196
+
197
+ // read weight block once, dot against all columns
198
+ for (size_t elem = 0; elem < qi / vdr; elem += WARP_SIZE) {
199
+ const int iqs = elem + vdr * (item_ct1.get_local_id(2) % (qi / vdr));
200
+
201
+ #pragma unroll
202
+ for (int j = 0; j < ncols_dst; ++j) {
203
+ tmp[j] += vec_dot_q_sycl(&x[ibx], &y[j * stride_col_y + iby], iqs);
204
+ }
205
+ }
206
+ }
207
+
208
+ // reduce within subgroup
209
+ #pragma unroll
210
+ for (int j = 0; j < ncols_dst; ++j) {
211
+ #pragma unroll
212
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
213
+ tmp[j] += dpct::permute_sub_group_by_xor(
214
+ item_ct1.get_sub_group(), tmp[j], mask);
215
+ }
216
+ }
217
+
218
+ if (item_ct1.get_local_id(2) == 0) {
219
+ #pragma unroll
220
+ for (int j = 0; j < ncols_dst; ++j) {
221
+ dst[j * stride_col_dst + row] = tmp[j];
222
+ }
223
+ }
224
+ }
225
+
103
226
  template <int qk, int qi, typename block_q_t, int vdr>
104
227
  static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
105
228
  const void *__restrict__ vy,
@@ -537,9 +660,9 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
537
660
  static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
538
661
  const int nrows, dpct::queue_ptr stream) {
539
662
  GGML_ASSERT(ncols % QK4_0 == 0);
540
- const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
541
- constexpr size_t num_subgroups = 16;
542
- GGML_ASSERT(block_num_y % num_subgroups == 0);
663
+ // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel.
664
+ constexpr size_t num_subgroups = WARP_SIZE;
665
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups;
543
666
 
544
667
  const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
545
668
  const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
@@ -553,6 +676,45 @@ static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy,
553
676
  });
554
677
  }
555
678
 
679
+ template <int ncols_dst>
680
+ static void reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols(
681
+ const void * vx, const void * vy, float * dst,
682
+ const int ncols, const int nrows,
683
+ const int stride_col_y_bytes, const int stride_col_dst,
684
+ dpct::queue_ptr stream) {
685
+ GGML_ASSERT(ncols % QK4_0 == 0);
686
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
687
+ constexpr size_t num_subgroups = 16;
688
+ GGML_ASSERT(block_num_y % num_subgroups == 0);
689
+ const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
690
+ const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
691
+ stream->submit([&](sycl::handler & cgh) {
692
+ cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
693
+ [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
694
+ mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>, ncols_dst>(
695
+ vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item);
696
+ });
697
+ });
698
+ }
699
+
700
+ static void reorder_mul_mat_vec_q4_0_q8_1_sycl_switch_ncols(
701
+ const void * vx, const void * vy, float * dst,
702
+ const int ncols, const int nrows, const int ncols_dst,
703
+ const int stride_col_y_bytes, const int stride_col_dst,
704
+ dpct::queue_ptr stream) {
705
+ switch (ncols_dst) {
706
+ case 1: reorder_mul_mat_vec_q4_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
707
+ case 2: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
708
+ case 3: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
709
+ case 4: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
710
+ case 5: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
711
+ case 6: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
712
+ case 7: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
713
+ case 8: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
714
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q4_0 reorder multi-col MMVQ", ncols_dst);
715
+ }
716
+ }
717
+
556
718
  static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
557
719
  dpct::queue_ptr stream) {
558
720
  GGML_ASSERT(ncols % QK4_0 == 0);
@@ -571,6 +733,45 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float *
571
733
  }
572
734
  }
573
735
 
736
+ template <int ncols_dst>
737
+ static void mul_mat_vec_q4_0_q8_1_sycl_ncols(
738
+ const void * vx, const void * vy, float * dst,
739
+ const int ncols, const int nrows,
740
+ const int stride_col_y, const int stride_col_dst,
741
+ dpct::queue_ptr stream) {
742
+ GGML_ASSERT(ncols % QK4_0 == 0);
743
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
744
+ const sycl::range<3> block_nums(1, 1, block_num_y);
745
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
746
+ stream->submit([&](sycl::handler & cgh) {
747
+ cgh.parallel_for(
748
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
749
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
750
+ mul_mat_vec_q_ncols<QK4_0, QI4_0, block_q4_0,
751
+ VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1, ncols_dst>(
752
+ vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
753
+ });
754
+ });
755
+ }
756
+
757
+ static void mul_mat_vec_q4_0_q8_1_sycl_switch_ncols(
758
+ const void * vx, const void * vy, float * dst,
759
+ const int ncols, const int nrows, const int ncols_dst,
760
+ const int stride_col_y, const int stride_col_dst,
761
+ dpct::queue_ptr stream) {
762
+ switch (ncols_dst) {
763
+ case 1: mul_mat_vec_q4_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
764
+ case 2: mul_mat_vec_q4_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
765
+ case 3: mul_mat_vec_q4_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
766
+ case 4: mul_mat_vec_q4_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
767
+ case 5: mul_mat_vec_q4_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
768
+ case 6: mul_mat_vec_q4_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
769
+ case 7: mul_mat_vec_q4_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
770
+ case 8: mul_mat_vec_q4_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
771
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q4_0 multi-col MMVQ", ncols_dst);
772
+ }
773
+ }
774
+
574
775
  static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
575
776
  float *dst, const int ncols,
576
777
  const int nrows,
@@ -595,6 +796,45 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
595
796
  }
596
797
  }
597
798
 
799
+ template <int ncols_dst>
800
+ static void mul_mat_vec_q4_1_q8_1_sycl_ncols(
801
+ const void * vx, const void * vy, float * dst,
802
+ const int ncols, const int nrows,
803
+ const int stride_col_y, const int stride_col_dst,
804
+ dpct::queue_ptr stream) {
805
+ GGML_ASSERT(ncols % QK4_1 == 0);
806
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
807
+ const sycl::range<3> block_nums(1, 1, block_num_y);
808
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
809
+ stream->submit([&](sycl::handler & cgh) {
810
+ cgh.parallel_for(
811
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
812
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
813
+ mul_mat_vec_q_ncols<QK4_0, QI4_1, block_q4_1,
814
+ VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1, ncols_dst>(
815
+ vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
816
+ });
817
+ });
818
+ }
819
+
820
+ static void mul_mat_vec_q4_1_q8_1_sycl_switch_ncols(
821
+ const void * vx, const void * vy, float * dst,
822
+ const int ncols, const int nrows, const int ncols_dst,
823
+ const int stride_col_y, const int stride_col_dst,
824
+ dpct::queue_ptr stream) {
825
+ switch (ncols_dst) {
826
+ case 1: mul_mat_vec_q4_1_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
827
+ case 2: mul_mat_vec_q4_1_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
828
+ case 3: mul_mat_vec_q4_1_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
829
+ case 4: mul_mat_vec_q4_1_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
830
+ case 5: mul_mat_vec_q4_1_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
831
+ case 6: mul_mat_vec_q4_1_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
832
+ case 7: mul_mat_vec_q4_1_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
833
+ case 8: mul_mat_vec_q4_1_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
834
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q4_1 multi-col MMVQ", ncols_dst);
835
+ }
836
+ }
837
+
598
838
  static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
599
839
  dpct::queue_ptr stream) {
600
840
  GGML_ASSERT(ncols % QK_MXFP4 == 0);
@@ -613,6 +853,101 @@ static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float
613
853
  }
614
854
  }
615
855
 
856
+ template <int ncols_dst>
857
+ static void mul_mat_vec_mxfp4_q8_1_sycl_ncols(
858
+ const void * vx, const void * vy, float * dst,
859
+ const int ncols, const int nrows,
860
+ const int stride_col_y, const int stride_col_dst,
861
+ dpct::queue_ptr stream) {
862
+ GGML_ASSERT(ncols % QK_MXFP4 == 0);
863
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
864
+ const sycl::range<3> block_nums(1, 1, block_num_y);
865
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
866
+ stream->submit([&](sycl::handler & cgh) {
867
+ cgh.parallel_for(
868
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
869
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
870
+ mul_mat_vec_q_ncols<QK_MXFP4, QI_MXFP4, block_mxfp4,
871
+ VDR_MXFP4_Q8_1_MMVQ, vec_dot_mxfp4_q8_1, ncols_dst>(
872
+ vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
873
+ });
874
+ });
875
+ }
876
+
877
+ static void mul_mat_vec_mxfp4_q8_1_sycl_switch_ncols(
878
+ const void * vx, const void * vy, float * dst,
879
+ const int ncols, const int nrows, const int ncols_dst,
880
+ const int stride_col_y, const int stride_col_dst,
881
+ dpct::queue_ptr stream) {
882
+ switch (ncols_dst) {
883
+ case 1: mul_mat_vec_mxfp4_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
884
+ case 2: mul_mat_vec_mxfp4_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
885
+ case 3: mul_mat_vec_mxfp4_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
886
+ case 4: mul_mat_vec_mxfp4_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
887
+ case 5: mul_mat_vec_mxfp4_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
888
+ case 6: mul_mat_vec_mxfp4_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
889
+ case 7: mul_mat_vec_mxfp4_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
890
+ case 8: mul_mat_vec_mxfp4_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
891
+ default: GGML_ABORT("unsupported ncols_dst=%d for MXFP4 multi-col MMVQ", ncols_dst);
892
+ }
893
+ }
894
+
895
+ static void mul_mat_vec_nvfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
896
+ dpct::queue_ptr stream) {
897
+ GGML_ASSERT(ncols % QK_NVFP4 == 0);
898
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
899
+ const sycl::range<3> block_nums(1, 1, block_num_y);
900
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
901
+
902
+ {
903
+ stream->submit([&](sycl::handler & cgh) {
904
+ cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
905
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
906
+ mul_mat_vec_q<QK_NVFP4, QI_NVFP4, block_nvfp4, VDR_NVFP4_Q8_1_MMVQ, vec_dot_nvfp4_q8_1>(
907
+ vx, vy, dst, ncols, nrows, item_ct1);
908
+ });
909
+ });
910
+ }
911
+ }
912
+
913
+ template <int ncols_dst>
914
+ static void mul_mat_vec_nvfp4_q8_1_sycl_ncols(
915
+ const void * vx, const void * vy, float * dst,
916
+ const int ncols, const int nrows,
917
+ const int stride_col_y, const int stride_col_dst,
918
+ dpct::queue_ptr stream) {
919
+ GGML_ASSERT(ncols % QK_NVFP4 == 0);
920
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
921
+ const sycl::range<3> block_nums(1, 1, block_num_y);
922
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
923
+ stream->submit([&](sycl::handler & cgh) {
924
+ cgh.parallel_for(
925
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
926
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
927
+ mul_mat_vec_q_ncols<QK_NVFP4, QI_NVFP4, block_nvfp4,
928
+ VDR_NVFP4_Q8_1_MMVQ, vec_dot_nvfp4_q8_1, ncols_dst>(
929
+ vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
930
+ });
931
+ });
932
+ }
933
+
934
+ static void mul_mat_vec_nvfp4_q8_1_sycl_switch_ncols(
935
+ const void * vx, const void * vy, float * dst,
936
+ const int ncols, const int nrows, const int ncols_dst,
937
+ const int stride_col_y, const int stride_col_dst,
938
+ dpct::queue_ptr stream) {
939
+ switch (ncols_dst) {
940
+ case 1: mul_mat_vec_nvfp4_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
941
+ case 2: mul_mat_vec_nvfp4_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
942
+ case 3: mul_mat_vec_nvfp4_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
943
+ case 4: mul_mat_vec_nvfp4_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
944
+ case 5: mul_mat_vec_nvfp4_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
945
+ case 6: mul_mat_vec_nvfp4_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
946
+ case 7: mul_mat_vec_nvfp4_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
947
+ case 8: mul_mat_vec_nvfp4_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
948
+ default: GGML_ABORT("unsupported ncols_dst=%d for NVFP4 multi-col MMVQ", ncols_dst);
949
+ }
950
+ }
616
951
 
617
952
  static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
618
953
  float *dst, const int ncols,
@@ -638,6 +973,45 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
638
973
  }
639
974
  }
640
975
 
976
+ template <int ncols_dst>
977
+ static void mul_mat_vec_q5_0_q8_1_sycl_ncols(
978
+ const void * vx, const void * vy, float * dst,
979
+ const int ncols, const int nrows,
980
+ const int stride_col_y, const int stride_col_dst,
981
+ dpct::queue_ptr stream) {
982
+ GGML_ASSERT(ncols % QK5_0 == 0);
983
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
984
+ const sycl::range<3> block_nums(1, 1, block_num_y);
985
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
986
+ stream->submit([&](sycl::handler & cgh) {
987
+ cgh.parallel_for(
988
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
989
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
990
+ mul_mat_vec_q_ncols<QK5_0, QI5_0, block_q5_0,
991
+ VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1, ncols_dst>(
992
+ vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
993
+ });
994
+ });
995
+ }
996
+
997
+ static void mul_mat_vec_q5_0_q8_1_sycl_switch_ncols(
998
+ const void * vx, const void * vy, float * dst,
999
+ const int ncols, const int nrows, const int ncols_dst,
1000
+ const int stride_col_y, const int stride_col_dst,
1001
+ dpct::queue_ptr stream) {
1002
+ switch (ncols_dst) {
1003
+ case 1: mul_mat_vec_q5_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1004
+ case 2: mul_mat_vec_q5_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1005
+ case 3: mul_mat_vec_q5_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1006
+ case 4: mul_mat_vec_q5_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1007
+ case 5: mul_mat_vec_q5_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1008
+ case 6: mul_mat_vec_q5_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1009
+ case 7: mul_mat_vec_q5_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1010
+ case 8: mul_mat_vec_q5_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1011
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q5_0 multi-col MMVQ", ncols_dst);
1012
+ }
1013
+ }
1014
+
641
1015
  static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
642
1016
  float *dst, const int ncols,
643
1017
  const int nrows,
@@ -662,6 +1036,103 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
662
1036
  }
663
1037
  }
664
1038
 
1039
+ template <int ncols_dst>
1040
+ static void mul_mat_vec_q5_1_q8_1_sycl_ncols(
1041
+ const void * vx, const void * vy, float * dst,
1042
+ const int ncols, const int nrows,
1043
+ const int stride_col_y, const int stride_col_dst,
1044
+ dpct::queue_ptr stream) {
1045
+ GGML_ASSERT(ncols % QK5_1 == 0);
1046
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
1047
+ const sycl::range<3> block_nums(1, 1, block_num_y);
1048
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
1049
+ stream->submit([&](sycl::handler & cgh) {
1050
+ cgh.parallel_for(
1051
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
1052
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1053
+ mul_mat_vec_q_ncols<QK5_1, QI5_1, block_q5_1,
1054
+ VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1, ncols_dst>(
1055
+ vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
1056
+ });
1057
+ });
1058
+ }
1059
+
1060
+ static void mul_mat_vec_q5_1_q8_1_sycl_switch_ncols(
1061
+ const void * vx, const void * vy, float * dst,
1062
+ const int ncols, const int nrows, const int ncols_dst,
1063
+ const int stride_col_y, const int stride_col_dst,
1064
+ dpct::queue_ptr stream) {
1065
+ switch (ncols_dst) {
1066
+ case 1: mul_mat_vec_q5_1_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1067
+ case 2: mul_mat_vec_q5_1_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1068
+ case 3: mul_mat_vec_q5_1_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1069
+ case 4: mul_mat_vec_q5_1_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1070
+ case 5: mul_mat_vec_q5_1_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1071
+ case 6: mul_mat_vec_q5_1_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1072
+ case 7: mul_mat_vec_q5_1_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1073
+ case 8: mul_mat_vec_q5_1_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1074
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q5_1 multi-col MMVQ", ncols_dst);
1075
+ }
1076
+ }
1077
+
1078
+ static void reorder_mul_mat_vec_q8_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
1079
+ const int nrows, dpct::queue_ptr stream) {
1080
+ GGML_ASSERT(ncols % QK8_0 == 0);
1081
+ // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel.
1082
+ constexpr size_t num_subgroups = WARP_SIZE;
1083
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups;
1084
+
1085
+ const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
1086
+ const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1087
+
1088
+ stream->submit([&](sycl::handler & cgh) {
1089
+ cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1090
+ [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1091
+ mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q8_0>>(vx, vy, dst, ncols, nrows,
1092
+ nd_item);
1093
+ });
1094
+ });
1095
+ }
1096
+
1097
+ template <int ncols_dst>
1098
+ static void reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols(
1099
+ const void * vx, const void * vy, float * dst,
1100
+ const int ncols, const int nrows,
1101
+ const int stride_col_y_bytes, const int stride_col_dst,
1102
+ dpct::queue_ptr stream) {
1103
+ GGML_ASSERT(ncols % QK8_0 == 0);
1104
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
1105
+ constexpr size_t num_subgroups = 16;
1106
+ GGML_ASSERT(block_num_y % num_subgroups == 0);
1107
+ const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
1108
+ const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1109
+ stream->submit([&](sycl::handler & cgh) {
1110
+ cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1111
+ [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1112
+ mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q8_0>, ncols_dst>(
1113
+ vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item);
1114
+ });
1115
+ });
1116
+ }
1117
+
1118
+ static void reorder_mul_mat_vec_q8_0_q8_1_sycl_switch_ncols(
1119
+ const void * vx, const void * vy, float * dst,
1120
+ const int ncols, const int nrows, const int ncols_dst,
1121
+ const int stride_col_y_bytes, const int stride_col_dst,
1122
+ dpct::queue_ptr stream) {
1123
+ switch (ncols_dst) {
1124
+ case 1: reorder_mul_mat_vec_q8_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1125
+ case 2: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1126
+ case 3: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1127
+ case 4: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1128
+ case 5: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1129
+ case 6: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1130
+ case 7: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1131
+ case 8: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1132
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q8_0 reorder multi-col MMVQ", ncols_dst);
1133
+ }
1134
+ }
1135
+
665
1136
  static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
666
1137
  float *dst, const int ncols,
667
1138
  const int nrows,
@@ -686,6 +1157,45 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
686
1157
  }
687
1158
  }
688
1159
 
1160
+ template <int ncols_dst>
1161
+ static void mul_mat_vec_q8_0_q8_1_sycl_ncols(
1162
+ const void * vx, const void * vy, float * dst,
1163
+ const int ncols, const int nrows,
1164
+ const int stride_col_y, const int stride_col_dst,
1165
+ dpct::queue_ptr stream) {
1166
+ GGML_ASSERT(ncols % QK8_0 == 0);
1167
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
1168
+ const sycl::range<3> block_nums(1, 1, block_num_y);
1169
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
1170
+ stream->submit([&](sycl::handler & cgh) {
1171
+ cgh.parallel_for(
1172
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
1173
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1174
+ mul_mat_vec_q_ncols<QK8_0, QI8_0, block_q8_0,
1175
+ VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1, ncols_dst>(
1176
+ vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
1177
+ });
1178
+ });
1179
+ }
1180
+
1181
+ static void mul_mat_vec_q8_0_q8_1_sycl_switch_ncols(
1182
+ const void * vx, const void * vy, float * dst,
1183
+ const int ncols, const int nrows, const int ncols_dst,
1184
+ const int stride_col_y, const int stride_col_dst,
1185
+ dpct::queue_ptr stream) {
1186
+ switch (ncols_dst) {
1187
+ case 1: mul_mat_vec_q8_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1188
+ case 2: mul_mat_vec_q8_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1189
+ case 3: mul_mat_vec_q8_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1190
+ case 4: mul_mat_vec_q8_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1191
+ case 5: mul_mat_vec_q8_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1192
+ case 6: mul_mat_vec_q8_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1193
+ case 7: mul_mat_vec_q8_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1194
+ case 8: mul_mat_vec_q8_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1195
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q8_0 multi-col MMVQ", ncols_dst);
1196
+ }
1197
+ }
1198
+
689
1199
  static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
690
1200
  float *dst, const int ncols,
691
1201
  const int nrows,
@@ -710,6 +1220,45 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
710
1220
  }
711
1221
  }
712
1222
 
1223
+ template <int ncols_dst>
1224
+ static void mul_mat_vec_q2_K_q8_1_sycl_ncols(
1225
+ const void * vx, const void * vy, float * dst,
1226
+ const int ncols, const int nrows,
1227
+ const int stride_col_y, const int stride_col_dst,
1228
+ dpct::queue_ptr stream) {
1229
+ GGML_ASSERT(ncols % QK_K == 0);
1230
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
1231
+ const sycl::range<3> block_nums(1, 1, block_num_y);
1232
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
1233
+ stream->submit([&](sycl::handler & cgh) {
1234
+ cgh.parallel_for(
1235
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
1236
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1237
+ mul_mat_vec_q_ncols<QK_K, QI2_K, block_q2_K,
1238
+ VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1, ncols_dst>(
1239
+ vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
1240
+ });
1241
+ });
1242
+ }
1243
+
1244
+ static void mul_mat_vec_q2_K_q8_1_sycl_switch_ncols(
1245
+ const void * vx, const void * vy, float * dst,
1246
+ const int ncols, const int nrows, const int ncols_dst,
1247
+ const int stride_col_y, const int stride_col_dst,
1248
+ dpct::queue_ptr stream) {
1249
+ switch (ncols_dst) {
1250
+ case 1: mul_mat_vec_q2_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1251
+ case 2: mul_mat_vec_q2_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1252
+ case 3: mul_mat_vec_q2_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1253
+ case 4: mul_mat_vec_q2_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1254
+ case 5: mul_mat_vec_q2_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1255
+ case 6: mul_mat_vec_q2_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1256
+ case 7: mul_mat_vec_q2_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1257
+ case 8: mul_mat_vec_q2_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1258
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q2_K multi-col MMVQ", ncols_dst);
1259
+ }
1260
+ }
1261
+
713
1262
  static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
714
1263
  float *dst, const int ncols,
715
1264
  const int nrows,
@@ -734,6 +1283,105 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
734
1283
  }
735
1284
  }
736
1285
 
1286
+ static void reorder_mul_mat_vec_q3_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
1287
+ const int nrows, dpct::queue_ptr stream) {
1288
+ GGML_ASSERT(ncols % QK_K == 0);
1289
+
1290
+ // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel.
1291
+ constexpr size_t num_subgroups = WARP_SIZE;
1292
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups;
1293
+
1294
+ const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
1295
+ const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1296
+
1297
+ stream->submit([&](sycl::handler & cgh) {
1298
+ cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1299
+ [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1300
+ mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q3_K>>(vx, vy, dst, ncols, nrows,
1301
+ nd_item);
1302
+ });
1303
+ });
1304
+ }
1305
+
1306
+ template <int ncols_dst>
1307
+ static void reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols(
1308
+ const void * vx, const void * vy, float * dst,
1309
+ const int ncols, const int nrows,
1310
+ const int stride_col_y_bytes, const int stride_col_dst,
1311
+ dpct::queue_ptr stream) {
1312
+ GGML_ASSERT(ncols % QK_K == 0);
1313
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
1314
+ constexpr size_t num_subgroups = 16;
1315
+ GGML_ASSERT(block_num_y % num_subgroups == 0);
1316
+ const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
1317
+ const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1318
+ stream->submit([&](sycl::handler & cgh) {
1319
+ cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1320
+ [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1321
+ mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q3_K>, ncols_dst>(
1322
+ vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item);
1323
+ });
1324
+ });
1325
+ }
1326
+
1327
+ static void reorder_mul_mat_vec_q3_k_q8_1_sycl_switch_ncols(
1328
+ const void * vx, const void * vy, float * dst,
1329
+ const int ncols, const int nrows, const int ncols_dst,
1330
+ const int stride_col_y_bytes, const int stride_col_dst,
1331
+ dpct::queue_ptr stream) {
1332
+ switch (ncols_dst) {
1333
+ case 1: reorder_mul_mat_vec_q3_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1334
+ case 2: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1335
+ case 3: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1336
+ case 4: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1337
+ case 5: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1338
+ case 6: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1339
+ case 7: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1340
+ case 8: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1341
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q3_K reorder multi-col MMVQ", ncols_dst);
1342
+ }
1343
+ }
1344
+
1345
+ template <int ncols_dst>
1346
+ static void mul_mat_vec_q3_K_q8_1_sycl_ncols(
1347
+ const void * vx, const void * vy, float * dst,
1348
+ const int ncols, const int nrows,
1349
+ const int stride_col_y, const int stride_col_dst,
1350
+ dpct::queue_ptr stream) {
1351
+ GGML_ASSERT(ncols % QK_K == 0);
1352
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
1353
+ const sycl::range<3> block_nums(1, 1, block_num_y);
1354
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
1355
+ stream->submit([&](sycl::handler & cgh) {
1356
+ cgh.parallel_for(
1357
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
1358
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1359
+ mul_mat_vec_q_ncols<QK_K, QI3_K, block_q3_K,
1360
+ VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1, ncols_dst>(
1361
+ vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
1362
+ });
1363
+ });
1364
+ }
1365
+
1366
+ static void mul_mat_vec_q3_K_q8_1_sycl_switch_ncols(
1367
+ const void * vx, const void * vy, float * dst,
1368
+ const int ncols, const int nrows, const int ncols_dst,
1369
+ const int stride_col_y, const int stride_col_dst,
1370
+ dpct::queue_ptr stream) {
1371
+ switch (ncols_dst) {
1372
+ case 1: mul_mat_vec_q3_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1373
+ case 2: mul_mat_vec_q3_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1374
+ case 3: mul_mat_vec_q3_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1375
+ case 4: mul_mat_vec_q3_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1376
+ case 5: mul_mat_vec_q3_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1377
+ case 6: mul_mat_vec_q3_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1378
+ case 7: mul_mat_vec_q3_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1379
+ case 8: mul_mat_vec_q3_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1380
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q3_K multi-col MMVQ", ncols_dst);
1381
+ }
1382
+ }
1383
+
1384
+
737
1385
  static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
738
1386
  float *dst, const int ncols,
739
1387
  const int nrows,
@@ -758,13 +1406,58 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
758
1406
  }
759
1407
  }
760
1408
 
1409
+ template <int ncols_dst>
1410
+ static void mul_mat_vec_q4_K_q8_1_sycl_ncols(
1411
+ const void * vx, const void * vy, float * dst,
1412
+ const int ncols, const int nrows,
1413
+ const int stride_col_y, const int stride_col_dst,
1414
+ dpct::queue_ptr stream) {
1415
+ GGML_ASSERT(ncols % QK_K == 0);
1416
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
1417
+ const sycl::range<3> block_nums(1, 1, block_num_y);
1418
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
1419
+
1420
+ stream->submit([&](sycl::handler & cgh) {
1421
+ cgh.parallel_for(
1422
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
1423
+ [=](sycl::nd_item<3> item_ct1)
1424
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1425
+ mul_mat_vec_q_ncols<QK_K, QI4_K, block_q4_K,
1426
+ VDR_Q4_K_Q8_1_MMVQ,
1427
+ vec_dot_q4_K_q8_1,
1428
+ ncols_dst>(
1429
+ vx, vy, dst, ncols, nrows,
1430
+ stride_col_y, stride_col_dst, item_ct1);
1431
+ });
1432
+ });
1433
+ }
1434
+
1435
+ static void mul_mat_vec_q4_K_q8_1_sycl_switch_ncols(
1436
+ const void * vx, const void * vy, float * dst,
1437
+ const int ncols, const int nrows,
1438
+ const int ncols_dst,
1439
+ const int stride_col_y, const int stride_col_dst,
1440
+ dpct::queue_ptr stream) {
1441
+ switch (ncols_dst) {
1442
+ case 1: mul_mat_vec_q4_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1443
+ case 2: mul_mat_vec_q4_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1444
+ case 3: mul_mat_vec_q4_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1445
+ case 4: mul_mat_vec_q4_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1446
+ case 5: mul_mat_vec_q4_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1447
+ case 6: mul_mat_vec_q4_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1448
+ case 7: mul_mat_vec_q4_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1449
+ case 8: mul_mat_vec_q4_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1450
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q4_K multi-col MMVQ", ncols_dst);
1451
+ }
1452
+ }
1453
+
761
1454
  static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
762
1455
  const int nrows, dpct::queue_ptr stream) {
763
1456
  GGML_ASSERT(ncols % QK_K == 0);
764
1457
 
765
- const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
766
- constexpr size_t num_subgroups = 16;
767
- GGML_ASSERT(block_num_y % num_subgroups == 0);
1458
+ // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel.
1459
+ constexpr size_t num_subgroups = WARP_SIZE;
1460
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups;
768
1461
 
769
1462
  const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
770
1463
  const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
@@ -778,6 +1471,44 @@ static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy,
778
1471
  });
779
1472
  }
780
1473
 
1474
+ template <int ncols_dst>
1475
+ static void reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols(
1476
+ const void * vx, const void * vy, float * dst,
1477
+ const int ncols, const int nrows,
1478
+ const int stride_col_y_bytes, const int stride_col_dst,
1479
+ dpct::queue_ptr stream) {
1480
+ GGML_ASSERT(ncols % QK_K == 0);
1481
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
1482
+ constexpr size_t num_subgroups = 16;
1483
+ GGML_ASSERT(block_num_y % num_subgroups == 0);
1484
+ const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
1485
+ const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1486
+ stream->submit([&](sycl::handler & cgh) {
1487
+ cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1488
+ [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1489
+ mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>, ncols_dst>(
1490
+ vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item);
1491
+ });
1492
+ });
1493
+ }
1494
+
1495
+ static void reorder_mul_mat_vec_q4_k_q8_1_sycl_switch_ncols(
1496
+ const void * vx, const void * vy, float * dst,
1497
+ const int ncols, const int nrows, const int ncols_dst,
1498
+ const int stride_col_y_bytes, const int stride_col_dst,
1499
+ dpct::queue_ptr stream) {
1500
+ switch (ncols_dst) {
1501
+ case 1: reorder_mul_mat_vec_q4_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1502
+ case 2: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1503
+ case 3: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1504
+ case 4: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1505
+ case 5: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1506
+ case 6: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1507
+ case 7: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1508
+ case 8: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1509
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q4_K reorder multi-col MMVQ", ncols_dst);
1510
+ }
1511
+ }
781
1512
 
782
1513
  static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
783
1514
  float *dst, const int ncols,
@@ -803,9 +1534,55 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
803
1534
  }
804
1535
  }
805
1536
 
806
- static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
1537
+ template <int ncols_dst>
1538
+ static void mul_mat_vec_q5_K_q8_1_sycl_ncols(
1539
+ const void * vx, const void * vy, float * dst,
1540
+ const int ncols, const int nrows,
1541
+ const int stride_col_y, const int stride_col_dst,
1542
+ dpct::queue_ptr stream) {
1543
+ GGML_ASSERT(ncols % QK_K == 0);
1544
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
1545
+ const sycl::range<3> block_nums(1, 1, block_num_y);
1546
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
1547
+
1548
+ stream->submit([&](sycl::handler & cgh) {
1549
+ cgh.parallel_for(
1550
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
1551
+ [=](sycl::nd_item<3> item_ct1)
1552
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1553
+ mul_mat_vec_q_ncols<QK_K, QI5_K, block_q5_K,
1554
+ VDR_Q5_K_Q8_1_MMVQ,
1555
+ vec_dot_q5_K_q8_1,
1556
+ ncols_dst>(
1557
+ vx, vy, dst, ncols, nrows,
1558
+ stride_col_y, stride_col_dst, item_ct1);
1559
+ });
1560
+ });
1561
+ }
1562
+
1563
+ static void mul_mat_vec_q5_K_q8_1_sycl_switch_ncols(
1564
+ const void * vx, const void * vy, float * dst,
1565
+ const int ncols, const int nrows,
1566
+ const int ncols_dst,
1567
+ const int stride_col_y, const int stride_col_dst,
1568
+ dpct::queue_ptr stream) {
1569
+ switch (ncols_dst) {
1570
+ case 1: mul_mat_vec_q5_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1571
+ case 2: mul_mat_vec_q5_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1572
+ case 3: mul_mat_vec_q5_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1573
+ case 4: mul_mat_vec_q5_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1574
+ case 5: mul_mat_vec_q5_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1575
+ case 6: mul_mat_vec_q5_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1576
+ case 7: mul_mat_vec_q5_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1577
+ case 8: mul_mat_vec_q5_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1578
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q5_K multi-col MMVQ", ncols_dst);
1579
+ }
1580
+ }
1581
+
1582
+ static void reorder_mul_mat_vec_q5_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
807
1583
  const int nrows, dpct::queue_ptr stream) {
808
1584
  GGML_ASSERT(ncols % QK_K == 0);
1585
+
809
1586
  const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
810
1587
  constexpr size_t num_subgroups = 16;
811
1588
  GGML_ASSERT(block_num_y % num_subgroups == 0);
@@ -813,6 +1590,64 @@ static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy,
813
1590
  const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
814
1591
  const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
815
1592
 
1593
+ stream->submit([&](sycl::handler & cgh) {
1594
+ cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1595
+ [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1596
+ mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q5_K>>(vx, vy, dst, ncols,
1597
+ nrows, nd_item);
1598
+ });
1599
+ });
1600
+ }
1601
+
1602
+ template <int ncols_dst>
1603
+ static void reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols(
1604
+ const void * vx, const void * vy, float * dst,
1605
+ const int ncols, const int nrows,
1606
+ const int stride_col_y_bytes, const int stride_col_dst,
1607
+ dpct::queue_ptr stream) {
1608
+ GGML_ASSERT(ncols % QK_K == 0);
1609
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
1610
+ constexpr size_t num_subgroups = 16;
1611
+ GGML_ASSERT(block_num_y % num_subgroups == 0);
1612
+ const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
1613
+ const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1614
+ stream->submit([&](sycl::handler & cgh) {
1615
+ cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1616
+ [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1617
+ mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q5_K>, ncols_dst>(
1618
+ vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item);
1619
+ });
1620
+ });
1621
+ }
1622
+
1623
+ static void reorder_mul_mat_vec_q5_k_q8_1_sycl_switch_ncols(
1624
+ const void * vx, const void * vy, float * dst,
1625
+ const int ncols, const int nrows, const int ncols_dst,
1626
+ const int stride_col_y_bytes, const int stride_col_dst,
1627
+ dpct::queue_ptr stream) {
1628
+ switch (ncols_dst) {
1629
+ case 1: reorder_mul_mat_vec_q5_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1630
+ case 2: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1631
+ case 3: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1632
+ case 4: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1633
+ case 5: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1634
+ case 6: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1635
+ case 7: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1636
+ case 8: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1637
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q5_K reorder multi-col MMVQ", ncols_dst);
1638
+ }
1639
+ }
1640
+
1641
+ static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
1642
+ const int nrows, dpct::queue_ptr stream) {
1643
+ GGML_ASSERT(ncols % QK_K == 0);
1644
+ // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel.
1645
+ constexpr size_t num_subgroups = WARP_SIZE;
1646
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups;
1647
+
1648
+ const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
1649
+ const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1650
+
816
1651
  stream->submit([&](sycl::handler & cgh) {
817
1652
  cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
818
1653
  [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
@@ -821,6 +1656,46 @@ static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy,
821
1656
  });
822
1657
  });
823
1658
  }
1659
+
1660
+ template <int ncols_dst>
1661
+ static void reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols(
1662
+ const void * vx, const void * vy, float * dst,
1663
+ const int ncols, const int nrows,
1664
+ const int stride_col_y_bytes, const int stride_col_dst,
1665
+ dpct::queue_ptr stream) {
1666
+ GGML_ASSERT(ncols % QK_K == 0);
1667
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
1668
+ constexpr size_t num_subgroups = 16;
1669
+ GGML_ASSERT(block_num_y % num_subgroups == 0);
1670
+ const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
1671
+ const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1672
+ stream->submit([&](sycl::handler & cgh) {
1673
+ cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1674
+ [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1675
+ mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>, ncols_dst>(
1676
+ vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item);
1677
+ });
1678
+ });
1679
+ }
1680
+
1681
+ static void reorder_mul_mat_vec_q6_k_q8_1_sycl_switch_ncols(
1682
+ const void * vx, const void * vy, float * dst,
1683
+ const int ncols, const int nrows, const int ncols_dst,
1684
+ const int stride_col_y_bytes, const int stride_col_dst,
1685
+ dpct::queue_ptr stream) {
1686
+ switch (ncols_dst) {
1687
+ case 1: reorder_mul_mat_vec_q6_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1688
+ case 2: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1689
+ case 3: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1690
+ case 4: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1691
+ case 5: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1692
+ case 6: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1693
+ case 7: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1694
+ case 8: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1695
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q6_K reorder multi-col MMVQ", ncols_dst);
1696
+ }
1697
+ }
1698
+
824
1699
  static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
825
1700
  float *dst, const int ncols,
826
1701
  const int nrows,
@@ -845,6 +1720,51 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
845
1720
  }
846
1721
  }
847
1722
 
1723
+ template <int ncols_dst>
1724
+ static void mul_mat_vec_q6_K_q8_1_sycl_ncols(
1725
+ const void * vx, const void * vy, float * dst,
1726
+ const int ncols, const int nrows,
1727
+ const int stride_col_y, const int stride_col_dst,
1728
+ dpct::queue_ptr stream) {
1729
+ GGML_ASSERT(ncols % QK_K == 0);
1730
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
1731
+ const sycl::range<3> block_nums(1, 1, block_num_y);
1732
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
1733
+
1734
+ stream->submit([&](sycl::handler & cgh) {
1735
+ cgh.parallel_for(
1736
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
1737
+ [=](sycl::nd_item<3> item_ct1)
1738
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1739
+ mul_mat_vec_q_ncols<QK_K, QI6_K, block_q6_K,
1740
+ VDR_Q6_K_Q8_1_MMVQ,
1741
+ vec_dot_q6_K_q8_1,
1742
+ ncols_dst>(
1743
+ vx, vy, dst, ncols, nrows,
1744
+ stride_col_y, stride_col_dst, item_ct1);
1745
+ });
1746
+ });
1747
+ }
1748
+
1749
+ static void mul_mat_vec_q6_K_q8_1_sycl_switch_ncols(
1750
+ const void * vx, const void * vy, float * dst,
1751
+ const int ncols, const int nrows,
1752
+ const int ncols_dst,
1753
+ const int stride_col_y, const int stride_col_dst,
1754
+ dpct::queue_ptr stream) {
1755
+ switch (ncols_dst) {
1756
+ case 1: mul_mat_vec_q6_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1757
+ case 2: mul_mat_vec_q6_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1758
+ case 3: mul_mat_vec_q6_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1759
+ case 4: mul_mat_vec_q6_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1760
+ case 5: mul_mat_vec_q6_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1761
+ case 6: mul_mat_vec_q6_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1762
+ case 7: mul_mat_vec_q6_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1763
+ case 8: mul_mat_vec_q6_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1764
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q6_K multi-col MMVQ", ncols_dst);
1765
+ }
1766
+ }
1767
+
848
1768
 
849
1769
  static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
850
1770
  float *dst, const int ncols,
@@ -1041,6 +1961,51 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
1041
1961
  }
1042
1962
  }
1043
1963
 
1964
+ template <int ncols_dst>
1965
+ static void mul_mat_vec_iq4_xs_q8_1_sycl_ncols(
1966
+ const void * vx, const void * vy, float * dst,
1967
+ const int ncols, const int nrows,
1968
+ const int stride_col_y, const int stride_col_dst,
1969
+ dpct::queue_ptr stream) {
1970
+ GGML_ASSERT(ncols % QK_K == 0);
1971
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
1972
+ const sycl::range<3> block_nums(1, 1, block_num_y);
1973
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
1974
+
1975
+ stream->submit([&](sycl::handler & cgh) {
1976
+ cgh.parallel_for(
1977
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
1978
+ [=](sycl::nd_item<3> item_ct1)
1979
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1980
+ mul_mat_vec_q_ncols<QK_K, QI4_XS/4, block_iq4_xs,
1981
+ 1,
1982
+ vec_dot_iq4_xs_q8_1,
1983
+ ncols_dst>(
1984
+ vx, vy, dst, ncols, nrows,
1985
+ stride_col_y, stride_col_dst, item_ct1);
1986
+ });
1987
+ });
1988
+ }
1989
+
1990
+ static void mul_mat_vec_iq4_xs_q8_1_sycl_switch_ncols(
1991
+ const void * vx, const void * vy, float * dst,
1992
+ const int ncols, const int nrows,
1993
+ const int ncols_dst,
1994
+ const int stride_col_y, const int stride_col_dst,
1995
+ dpct::queue_ptr stream) {
1996
+ switch (ncols_dst) {
1997
+ case 1: mul_mat_vec_iq4_xs_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1998
+ case 2: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1999
+ case 3: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
2000
+ case 4: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
2001
+ case 5: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
2002
+ case 6: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
2003
+ case 7: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
2004
+ case 8: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
2005
+ default: GGML_ABORT("unsupported ncols_dst=%d for IQ4_XS multi-col MMVQ", ncols_dst);
2006
+ }
2007
+ }
2008
+
1044
2009
  void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
1045
2010
  ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
1046
2011
  const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low,
@@ -1067,50 +2032,219 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
1067
2032
  case GGML_TYPE_Q4_0:
1068
2033
  if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
1069
2034
  ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
1070
- GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n");
1071
- reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1072
- } else {
2035
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2036
+ const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs;
2037
+ const int stride_col_dst = dst->ne[0];
2038
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2039
+ reorder_mul_mat_vec_q4_0_q8_1_sycl_switch_ncols(
2040
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2041
+ src1_ncols, stride_col_y_bytes, stride_col_dst, stream);
2042
+ return;
2043
+ } else {
2044
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n");
2045
+ reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2046
+ }
2047
+ } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2048
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2049
+ const int stride_col_dst = dst->ne[0];
2050
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2051
+ mul_mat_vec_q4_0_q8_1_sycl_switch_ncols(
2052
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2053
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2054
+ return;
2055
+ } else if (i == 0 || src1_ncols == 1) {
1073
2056
  GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl\n");
1074
2057
  mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1075
2058
  }
1076
2059
  break;
1077
2060
  case GGML_TYPE_Q4_1:
1078
- mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2061
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2062
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2063
+ const int stride_col_dst = dst->ne[0];
2064
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_1_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2065
+ mul_mat_vec_q4_1_q8_1_sycl_switch_ncols(
2066
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2067
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2068
+ return;
2069
+ } else if (i == 0 || src1_ncols == 1) {
2070
+ mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2071
+ }
1079
2072
  break;
1080
2073
  case GGML_TYPE_Q5_0:
1081
- mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2074
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2075
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2076
+ const int stride_col_dst = dst->ne[0];
2077
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2078
+ mul_mat_vec_q5_0_q8_1_sycl_switch_ncols(
2079
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2080
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2081
+ return;
2082
+ } else if (i == 0 || src1_ncols == 1) {
2083
+ mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2084
+ }
1082
2085
  break;
1083
2086
  case GGML_TYPE_Q5_1:
1084
- mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2087
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2088
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2089
+ const int stride_col_dst = dst->ne[0];
2090
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_1_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2091
+ mul_mat_vec_q5_1_q8_1_sycl_switch_ncols(
2092
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2093
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2094
+ return;
2095
+ } else if (i == 0 || src1_ncols == 1) {
2096
+ mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2097
+ }
1085
2098
  break;
1086
2099
  case GGML_TYPE_Q8_0:
1087
- mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2100
+ if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
2101
+ ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
2102
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2103
+ const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs;
2104
+ const int stride_col_dst = dst->ne[0];
2105
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2106
+ reorder_mul_mat_vec_q8_0_q8_1_sycl_switch_ncols(
2107
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2108
+ src1_ncols, stride_col_y_bytes, stride_col_dst, stream);
2109
+ return;
2110
+ } else {
2111
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl\n");
2112
+ reorder_mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2113
+ }
2114
+ } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2115
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2116
+ const int stride_col_dst = dst->ne[0];
2117
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q8_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2118
+ mul_mat_vec_q8_0_q8_1_sycl_switch_ncols(
2119
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2120
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2121
+ return;
2122
+ } else if (i == 0 || src1_ncols == 1) {
2123
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q8_0_q8_1_sycl\n");
2124
+ mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2125
+ }
1088
2126
  break;
1089
2127
  case GGML_TYPE_Q2_K:
1090
- mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2128
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2129
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2130
+ const int stride_col_dst = dst->ne[0];
2131
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q2_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2132
+ mul_mat_vec_q2_K_q8_1_sycl_switch_ncols(
2133
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2134
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2135
+ return;
2136
+ } else if (i == 0 || src1_ncols == 1) {
2137
+ mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2138
+ }
1091
2139
  break;
1092
2140
  case GGML_TYPE_Q3_K:
1093
- mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2141
+ if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
2142
+ ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
2143
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2144
+ const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs;
2145
+ const int stride_col_dst = dst->ne[0];
2146
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q3_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2147
+ reorder_mul_mat_vec_q3_k_q8_1_sycl_switch_ncols(
2148
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2149
+ src1_ncols, stride_col_y_bytes, stride_col_dst, stream);
2150
+ return;
2151
+ } else {
2152
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q3_k_q8_1_sycl\n");
2153
+ reorder_mul_mat_vec_q3_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2154
+ }
2155
+ } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2156
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2157
+ const int stride_col_dst = dst->ne[0];
2158
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q3_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2159
+ mul_mat_vec_q3_K_q8_1_sycl_switch_ncols(
2160
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2161
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2162
+ return;
2163
+ } else if (i == 0 || src1_ncols == 1) {
2164
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q3_K_q8_1_sycl\n");
2165
+ mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2166
+ }
1094
2167
  break;
1095
2168
  case GGML_TYPE_Q4_K:
1096
2169
  if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
1097
2170
  ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
1098
- GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_k_q8_1_sycl\n");
1099
- reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1100
- } else {
2171
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2172
+ const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs;
2173
+ const int stride_col_dst = dst->ne[0];
2174
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2175
+ reorder_mul_mat_vec_q4_k_q8_1_sycl_switch_ncols(
2176
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2177
+ src1_ncols, stride_col_y_bytes, stride_col_dst, stream);
2178
+ return;
2179
+ } else {
2180
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_k_q8_1_sycl\n");
2181
+ reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2182
+ }
2183
+ } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2184
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2185
+ const int stride_col_dst = dst->ne[0];
2186
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2187
+ mul_mat_vec_q4_K_q8_1_sycl_switch_ncols(
2188
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2189
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2190
+ return;
2191
+ } else if (i == 0 || src1_ncols == 1) {
1101
2192
  GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_K_q8_1_sycl\n");
1102
2193
  mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1103
2194
  }
1104
2195
  break;
1105
2196
  case GGML_TYPE_Q5_K:
1106
- mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2197
+ if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
2198
+ ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
2199
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2200
+ const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs;
2201
+ const int stride_col_dst = dst->ne[0];
2202
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q5_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2203
+ reorder_mul_mat_vec_q5_k_q8_1_sycl_switch_ncols(
2204
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2205
+ src1_ncols, stride_col_y_bytes, stride_col_dst, stream);
2206
+ return;
2207
+ } else {
2208
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q5_k_q8_1_sycl\n");
2209
+ reorder_mul_mat_vec_q5_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2210
+ }
2211
+ } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2212
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2213
+ const int stride_col_dst = dst->ne[0];
2214
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2215
+ mul_mat_vec_q5_K_q8_1_sycl_switch_ncols(
2216
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2217
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2218
+ return;
2219
+ } else if (i == 0 || src1_ncols == 1) {
2220
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_K_q8_1_sycl\n");
2221
+ mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2222
+ }
1107
2223
  break;
1108
2224
  case GGML_TYPE_Q6_K:
1109
2225
  if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
1110
2226
  ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
1111
- GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n");
1112
- reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1113
- } else {
2227
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2228
+ const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs;
2229
+ const int stride_col_dst = dst->ne[0];
2230
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2231
+ reorder_mul_mat_vec_q6_k_q8_1_sycl_switch_ncols(
2232
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2233
+ src1_ncols, stride_col_y_bytes, stride_col_dst, stream);
2234
+ return;
2235
+ } else {
2236
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n");
2237
+ reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2238
+ }
2239
+ } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2240
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2241
+ const int stride_col_dst = dst->ne[0];
2242
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2243
+ mul_mat_vec_q6_K_q8_1_sycl_switch_ncols(
2244
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2245
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2246
+ return;
2247
+ } else if (i == 0 || src1_ncols == 1) {
1114
2248
  GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_k_q8_1_sycl\n");
1115
2249
  mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1116
2250
  }
@@ -1140,13 +2274,46 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
1140
2274
  mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1141
2275
  break;
1142
2276
  case GGML_TYPE_IQ4_XS:
1143
- mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2277
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2278
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2279
+ const int stride_col_dst = dst->ne[0];
2280
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_iq4_xs_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2281
+ mul_mat_vec_iq4_xs_q8_1_sycl_switch_ncols(
2282
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2283
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2284
+ return;
2285
+ } else if (i == 0 || src1_ncols == 1) {
2286
+ mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2287
+ }
1144
2288
  break;
1145
2289
  case GGML_TYPE_MXFP4:
1146
- mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2290
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2291
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2292
+ const int stride_col_dst = dst->ne[0];
2293
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_mxfp4_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2294
+ mul_mat_vec_mxfp4_q8_1_sycl_switch_ncols(
2295
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2296
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2297
+ return;
2298
+ } else if (i == 0 || src1_ncols == 1) {
2299
+ mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2300
+ }
2301
+ break;
2302
+ case GGML_TYPE_NVFP4:
2303
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2304
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2305
+ const int stride_col_dst = dst->ne[0];
2306
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_nvfp4_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2307
+ mul_mat_vec_nvfp4_q8_1_sycl_switch_ncols(
2308
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2309
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2310
+ return;
2311
+ } else if (i == 0 || src1_ncols == 1) {
2312
+ mul_mat_vec_nvfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2313
+ }
1147
2314
  break;
1148
2315
  default:
1149
- GGML_ABORT("fatal error");
2316
+ GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(src0->type));
1150
2317
  }
1151
2318
  }
1152
2319
  GGML_UNUSED(src1);
@@ -1154,3 +2321,154 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
1154
2321
  GGML_UNUSED(src1_ddf_i);
1155
2322
  GGML_UNUSED(ctx);
1156
2323
  }
2324
+
2325
+ // src1_row_stride: 0 for shared src1 (gate/up proj), else per-expert stride (down proj).
2326
+ template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
2327
+ static void mul_mat_vec_q_moe(
2328
+ const void * __restrict__ vx_base, const void * __restrict__ vy_base,
2329
+ float * __restrict__ dst_base, const int32_t * __restrict__ ids_dev,
2330
+ const int ncols, const int nrows,
2331
+ const size_t expert_weight_stride, const size_t dst_row_stride,
2332
+ const size_t src1_row_stride,
2333
+ const sycl::nd_item<3> & item_ct1) {
2334
+
2335
+ const int expert_idx = item_ct1.get_group(1);
2336
+ const int i02 = ids_dev[expert_idx];
2337
+
2338
+ const char * vx = (const char *) vx_base + (size_t) i02 * expert_weight_stride;
2339
+ const char * vy = (const char *) vy_base + (size_t) expert_idx * src1_row_stride;
2340
+ float * dst = (float *) ((char *) dst_base + (size_t) expert_idx * dst_row_stride);
2341
+
2342
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);
2343
+
2344
+ if (row >= nrows) {
2345
+ return;
2346
+ }
2347
+
2348
+ const int blocks_per_row = ncols / qk;
2349
+ constexpr int blocks_per_warp = (vdr * WARP_SIZE + qi - 1) / qi;
2350
+
2351
+ float tmp = 0.0f;
2352
+
2353
+ const block_q_t * x = (const block_q_t *) vx;
2354
+ const block_q8_1 * y = (const block_q8_1 *) vy;
2355
+
2356
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; i += blocks_per_warp) {
2357
+ const int ibx = row * blocks_per_row + i;
2358
+ const int iby = i * (qk / QK8_1);
2359
+
2360
+ for (size_t elem = 0; elem < qi / vdr; elem += WARP_SIZE) {
2361
+ const int iqs = elem + vdr * (item_ct1.get_local_id(2) % (qi / vdr));
2362
+ tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
2363
+ }
2364
+ }
2365
+
2366
+ #pragma unroll
2367
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
2368
+ tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
2369
+ }
2370
+
2371
+ if (item_ct1.get_local_id(2) == 0) {
2372
+ dst[row] = tmp;
2373
+ }
2374
+ }
2375
+
2376
+ template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
2377
+ static void launch_mul_mat_vec_q_moe(
2378
+ const void * vx_base, const void * vy, const int32_t * ids_dev,
2379
+ float * dst_base, const int ncols, const int nrows, const int n_experts_used,
2380
+ const size_t expert_weight_stride, const size_t dst_row_stride,
2381
+ const size_t src1_row_stride,
2382
+ dpct::queue_ptr stream) {
2383
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
2384
+ const sycl::range<3> block_nums(1, (unsigned) n_experts_used, (unsigned) block_num_y);
2385
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
2386
+ stream->submit([&](sycl::handler & cgh) {
2387
+ cgh.parallel_for(
2388
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
2389
+ [=](sycl::nd_item<3> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
2390
+ mul_mat_vec_q_moe<qk, qi, block_q_t, vdr, vec_dot_q_sycl>(
2391
+ vx_base, vy, dst_base, ids_dev, ncols, nrows,
2392
+ expert_weight_stride, dst_row_stride, src1_row_stride, item);
2393
+ });
2394
+ });
2395
+ }
2396
+
2397
+ bool ggml_sycl_mul_mat_vec_q_id(
2398
+ enum ggml_type src0_type,
2399
+ const void * vx_base,
2400
+ const void * vy,
2401
+ const int32_t * ids_dev,
2402
+ float * dst_base,
2403
+ int ncols,
2404
+ int nrows,
2405
+ int n_experts_used,
2406
+ size_t expert_weight_stride,
2407
+ size_t dst_row_stride,
2408
+ size_t src1_row_stride,
2409
+ dpct::queue_ptr stream) {
2410
+ switch (src0_type) {
2411
+ case GGML_TYPE_Q4_0:
2412
+ launch_mul_mat_vec_q_moe<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
2413
+ vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
2414
+ expert_weight_stride, dst_row_stride, src1_row_stride, stream);
2415
+ return true;
2416
+ case GGML_TYPE_Q4_1:
2417
+ launch_mul_mat_vec_q_moe<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
2418
+ vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
2419
+ expert_weight_stride, dst_row_stride, src1_row_stride, stream);
2420
+ return true;
2421
+ case GGML_TYPE_Q5_0:
2422
+ launch_mul_mat_vec_q_moe<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
2423
+ vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
2424
+ expert_weight_stride, dst_row_stride, src1_row_stride, stream);
2425
+ return true;
2426
+ case GGML_TYPE_Q5_1:
2427
+ launch_mul_mat_vec_q_moe<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
2428
+ vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
2429
+ expert_weight_stride, dst_row_stride, src1_row_stride, stream);
2430
+ return true;
2431
+ case GGML_TYPE_Q8_0:
2432
+ launch_mul_mat_vec_q_moe<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
2433
+ vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
2434
+ expert_weight_stride, dst_row_stride, src1_row_stride, stream);
2435
+ return true;
2436
+ case GGML_TYPE_Q2_K:
2437
+ launch_mul_mat_vec_q_moe<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
2438
+ vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
2439
+ expert_weight_stride, dst_row_stride, src1_row_stride, stream);
2440
+ return true;
2441
+ case GGML_TYPE_Q3_K:
2442
+ launch_mul_mat_vec_q_moe<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
2443
+ vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
2444
+ expert_weight_stride, dst_row_stride, src1_row_stride, stream);
2445
+ return true;
2446
+ case GGML_TYPE_Q4_K:
2447
+ launch_mul_mat_vec_q_moe<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
2448
+ vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
2449
+ expert_weight_stride, dst_row_stride, src1_row_stride, stream);
2450
+ return true;
2451
+ case GGML_TYPE_Q5_K:
2452
+ launch_mul_mat_vec_q_moe<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
2453
+ vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
2454
+ expert_weight_stride, dst_row_stride, src1_row_stride, stream);
2455
+ return true;
2456
+ case GGML_TYPE_Q6_K:
2457
+ launch_mul_mat_vec_q_moe<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
2458
+ vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
2459
+ expert_weight_stride, dst_row_stride, src1_row_stride, stream);
2460
+ return true;
2461
+ case GGML_TYPE_MXFP4:
2462
+ launch_mul_mat_vec_q_moe<QK_MXFP4, QI_MXFP4, block_mxfp4, VDR_MXFP4_Q8_1_MMVQ, vec_dot_mxfp4_q8_1>(
2463
+ vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
2464
+ expert_weight_stride, dst_row_stride, src1_row_stride, stream);
2465
+ return true;
2466
+ case GGML_TYPE_NVFP4:
2467
+ launch_mul_mat_vec_q_moe<QK_NVFP4, QI_NVFP4, block_nvfp4, VDR_NVFP4_Q8_1_MMVQ, vec_dot_nvfp4_q8_1>(
2468
+ vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
2469
+ expert_weight_stride, dst_row_stride, src1_row_stride, stream);
2470
+ return true;
2471
+ default:
2472
+ return false;
2473
+ }
2474
+ }