whispercpp 1.3.5 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1017)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/LICENSE +1 -1
  5. data/README.md +133 -3
  6. data/Rakefile +18 -3
  7. data/ext/dependencies.rb +10 -4
  8. data/ext/dependencies_for_windows.rb +17 -0
  9. data/ext/extconf.rb +20 -7
  10. data/ext/options.rb +54 -14
  11. data/ext/options_for_windows.rb +51 -0
  12. data/ext/ruby_whisper.c +56 -46
  13. data/ext/ruby_whisper.h +165 -2
  14. data/ext/ruby_whisper_context.c +297 -126
  15. data/ext/ruby_whisper_context_params.c +163 -0
  16. data/ext/ruby_whisper_log_queue.c +180 -0
  17. data/ext/ruby_whisper_log_settable.h +47 -0
  18. data/ext/ruby_whisper_model.c +0 -1
  19. data/ext/ruby_whisper_parakeet.c +49 -0
  20. data/ext/ruby_whisper_parakeet_context.c +304 -0
  21. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  22. data/ext/ruby_whisper_parakeet_model.c +84 -0
  23. data/ext/ruby_whisper_parakeet_params.c +548 -0
  24. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  25. data/ext/ruby_whisper_parakeet_token.c +188 -0
  26. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  27. data/ext/ruby_whisper_params.c +256 -66
  28. data/ext/ruby_whisper_segment.c +6 -7
  29. data/ext/ruby_whisper_token.c +29 -9
  30. data/ext/ruby_whisper_transcribe.cpp +46 -16
  31. data/ext/ruby_whisper_vad_context.c +48 -1
  32. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  33. data/ext/ruby_whisper_vad_params.c +0 -1
  34. data/ext/ruby_whisper_vad_segment.c +0 -1
  35. data/ext/ruby_whisper_vad_segments.c +0 -1
  36. data/ext/sources/CMakeLists.txt +41 -3
  37. data/ext/sources/CMakePresets.json +95 -0
  38. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  39. data/ext/sources/cmake/parakeet.pc.in +10 -0
  40. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  41. data/ext/sources/cmake/whisper.pc.in +1 -1
  42. data/ext/sources/examples/CMakeLists.txt +4 -2
  43. data/ext/sources/examples/bench/bench.cpp +24 -19
  44. data/ext/sources/examples/cli/cli.cpp +51 -9
  45. data/ext/sources/examples/common-ggml.cpp +4 -0
  46. data/ext/sources/examples/common-whisper.cpp +139 -67
  47. data/ext/sources/examples/common-whisper.h +11 -0
  48. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  49. data/ext/sources/examples/miniaudio.h +4507 -2131
  50. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  51. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  52. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  53. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  54. data/ext/sources/examples/server/server.cpp +213 -163
  55. data/ext/sources/ggml/CMakeLists.txt +29 -15
  56. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  57. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  58. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  59. data/ext/sources/ggml/include/ggml-backend.h +73 -11
  60. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  61. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  62. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  63. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  64. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  65. data/ext/sources/ggml/include/ggml-rpc.h +8 -3
  66. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  67. data/ext/sources/ggml/include/ggml.h +155 -16
  68. data/ext/sources/ggml/include/gguf.h +10 -2
  69. data/ext/sources/ggml/src/CMakeLists.txt +25 -5
  70. data/ext/sources/ggml/src/ggml-alloc.c +9 -10
  71. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  72. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  73. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  74. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
  76. data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
  79. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
  83. data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
  85. data/ext/sources/ggml/src/ggml-common.h +22 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
  88. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  89. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
  91. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
  92. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
  99. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
  101. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  102. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  103. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
  105. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
  106. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
  107. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  108. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
  109. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
  110. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
  111. data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
  112. data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
  113. data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
  115. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  116. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
  117. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
  118. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  119. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  120. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  121. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  122. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  123. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  124. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  125. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  126. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  127. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  128. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  129. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  130. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  131. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  132. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  133. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  134. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
  135. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  136. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
  137. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  138. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  139. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
  140. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
  141. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  142. data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
  143. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  144. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  145. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  146. data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
  147. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  148. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
  149. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  150. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
  151. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
  152. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  153. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
  154. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
  155. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
  156. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  157. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
  158. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  159. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  160. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  161. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
  162. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  163. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  164. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
  165. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  166. data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
  167. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
  168. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  169. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
  170. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  171. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
  172. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
  173. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  174. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
  175. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
  176. data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
  177. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  178. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  179. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
  180. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  181. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
  182. data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
  183. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  184. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  185. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  186. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  187. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  188. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  189. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  190. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
  191. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  192. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  193. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
  226. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
  227. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  228. data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
  229. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
  230. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
  231. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  232. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
  233. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
  235. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
  254. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  255. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  256. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  257. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  258. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  259. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  260. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  261. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  262. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
  263. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
  264. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  265. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  266. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
  267. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
  268. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
  269. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  270. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
  271. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  272. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  273. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  274. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  275. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  276. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  278. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  297. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
  302. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
  303. data/ext/sources/ggml/src/ggml-impl.h +68 -1
  304. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  305. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  306. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  307. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  308. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
  309. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
  310. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
  311. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
  312. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
  313. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
  314. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
  315. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
  316. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  317. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
  318. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  367. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  368. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  369. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  370. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  371. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  372. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  373. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  374. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  375. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  376. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  377. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  378. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  379. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  380. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  381. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  382. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  383. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  384. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  385. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  386. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  387. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  388. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  389. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
  390. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  391. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  392. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  393. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  394. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  395. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  396. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  397. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  398. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  399. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  400. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  401. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
  402. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  403. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
  404. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  405. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
  406. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
  407. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  412. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  413. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  414. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  415. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  416. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  417. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  418. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  419. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  420. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  421. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  422. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  423. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  424. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
  425. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  426. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  427. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  428. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  429. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  430. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  431. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  432. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
  433. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
  434. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  435. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  436. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  437. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  438. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  439. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  440. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
  441. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  442. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
  443. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
  444. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
  445. data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
  446. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  447. data/ext/sources/ggml/src/ggml-quants.c +385 -119
  448. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  449. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  450. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  451. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  452. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  453. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
  454. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  455. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
  456. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  457. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  458. data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
  459. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
  460. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
  461. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  462. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  463. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  464. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  465. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  466. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  467. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  468. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  469. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
  470. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
  471. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  472. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  473. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
  474. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
  475. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
  476. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
  477. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
  478. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  479. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  480. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  481. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
  482. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
  483. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  484. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  485. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
  486. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  487. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  488. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  489. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  490. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  491. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  492. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  493. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  494. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
  495. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  496. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  497. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  498. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  499. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  500. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  501. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  502. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  503. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  504. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  505. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  506. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  507. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  508. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  509. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  510. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  511. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  512. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  513. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  514. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  515. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  516. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
  517. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
  518. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
  519. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
  520. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
  521. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
  522. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
  523. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
  524. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
  525. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
  526. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
  527. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
  530. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
  531. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
  532. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
  534. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
  535. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
  536. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
  537. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
  538. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
  539. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
  540. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
  541. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
  542. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
  543. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
  544. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
  545. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
  547. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
  548. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
  549. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
  550. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
  551. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
  552. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  553. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  554. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  555. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
  556. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  557. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  558. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  559. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  560. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  561. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  562. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  563. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  564. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  565. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  566. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  567. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  568. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  569. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  570. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  571. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  572. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  573. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  574. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  575. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  576. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  577. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  578. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
  579. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
  580. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  581. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
  582. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  583. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  584. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  585. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  586. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  587. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  588. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  589. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  590. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  591. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  592. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
  593. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  594. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  595. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  596. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
  597. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  599. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
  663. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  664. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
  665. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
  666. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  667. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  668. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  669. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  670. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  671. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
  672. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
  673. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
  674. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  675. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
  676. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  677. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
  678. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  679. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  680. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  681. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  682. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  683. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  684. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
  685. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  686. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  687. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
  688. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  689. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  690. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  691. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
  692. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
  693. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
  694. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  695. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  696. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  697. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  698. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  699. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  700. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  701. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  702. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  703. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
  704. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  705. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
  706. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  707. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  708. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  709. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  710. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  711. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  712. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
  713. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  714. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
  715. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  716. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
  717. data/ext/sources/ggml/src/ggml.c +268 -52
  718. data/ext/sources/ggml/src/gguf.cpp +377 -47
  719. data/ext/sources/include/parakeet.h +342 -0
  720. data/ext/sources/include/whisper.h +10 -0
  721. data/ext/sources/media/matmul.png +0 -0
  722. data/ext/sources/src/CMakeLists.txt +23 -0
  723. data/ext/sources/src/parakeet-arch.h +188 -0
  724. data/ext/sources/src/parakeet.cpp +3838 -0
  725. data/ext/sources/src/whisper.cpp +62 -40
  726. data/extsources.rb +26 -10
  727. data/lib/whisper/log_settable.rb +36 -0
  728. data/lib/whisper/model/uri.rb +13 -1
  729. data/lib/whisper/output.rb +74 -0
  730. data/sig/whisper.rbs +445 -55
  731. data/test/helper.rb +2 -0
  732. data/test/jfk_reader/jfk_reader.c +50 -7
  733. data/test/test_callback.rb +1 -0
  734. data/test/test_context_params.rb +82 -0
  735. data/test/test_package.rb +6 -5
  736. data/test/test_parakeet.rb +28 -0
  737. data/test/test_parakeet_callback.rb +107 -0
  738. data/test/test_parakeet_context.rb +116 -0
  739. data/test/test_parakeet_context_params.rb +24 -0
  740. data/test/test_parakeet_model.rb +21 -0
  741. data/test/test_parakeet_params.rb +78 -0
  742. data/test/test_parakeet_segment.rb +42 -0
  743. data/test/test_parakeet_token.rb +73 -0
  744. data/test/test_params.rb +2 -0
  745. data/test/test_token.rb +11 -0
  746. data/test/test_vad_context.rb +58 -8
  747. data/test/test_vad_segment.rb +1 -1
  748. data/test/test_whisper.rb +44 -6
  749. data/whispercpp.gemspec +2 -2
  750. metadata +426 -280
  751. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  752. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  753. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  754. data/ext/sources/bindings/javascript/package.json +0 -26
  755. data/ext/sources/bindings/javascript/whisper.js +0 -19
  756. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  757. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  758. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  759. data/ext/sources/examples/addon.node/index.js +0 -59
  760. data/ext/sources/examples/addon.node/package.json +0 -16
  761. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  762. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  763. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  764. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  765. data/ext/sources/examples/coi-serviceworker.js +0 -146
  766. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  767. data/ext/sources/examples/command/command.cpp +0 -802
  768. data/ext/sources/examples/command/commands.txt +0 -9
  769. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  770. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  771. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  772. data/ext/sources/examples/generate-karaoke.sh +0 -57
  773. data/ext/sources/examples/helpers.js +0 -191
  774. data/ext/sources/examples/livestream.sh +0 -112
  775. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  776. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  777. data/ext/sources/examples/lsp/whisper.vim +0 -362
  778. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  779. data/ext/sources/examples/python/whisper_processor.py +0 -54
  780. data/ext/sources/examples/server/bench.js +0 -29
  781. data/ext/sources/examples/server.py +0 -120
  782. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  783. data/ext/sources/examples/stream/stream.cpp +0 -437
  784. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  785. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  786. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  787. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  788. data/ext/sources/examples/sycl/build.sh +0 -22
  789. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  790. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  791. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
  792. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  793. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
  794. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
  795. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
  796. data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
  797. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
  798. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  799. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
  800. data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
  801. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
  802. data/ext/sources/examples/talk-llama/llama-context.h +0 -360
  803. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  804. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
  805. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  806. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  807. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
  808. data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
  809. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
  810. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
  811. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  812. data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
  813. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  814. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  815. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
  816. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  817. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
  818. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
  819. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  820. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  821. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  822. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
  823. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  824. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  825. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  826. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
  827. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  828. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
  829. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
  830. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
  831. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
  832. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
  833. data/ext/sources/examples/talk-llama/llama-model.h +0 -544
  834. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
  835. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  836. data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
  837. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
  838. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
  839. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
  840. data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
  841. data/ext/sources/examples/talk-llama/llama.h +0 -1540
  842. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
  843. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  844. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  845. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
  846. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  847. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
  848. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
  849. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
  850. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
  851. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
  852. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  853. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  854. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  855. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  856. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  857. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  858. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  859. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
  860. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  861. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
  862. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
  863. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
  864. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  865. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
  866. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  867. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  868. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  869. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
  870. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  871. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  872. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  873. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  874. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  875. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  876. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  877. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
  878. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  879. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  880. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
  881. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
  882. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
  883. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  884. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
  885. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  886. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
  887. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  888. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  889. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  890. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
  891. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  892. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  893. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  894. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
  895. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  896. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
  897. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
  898. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
  899. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
  900. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  901. data/ext/sources/examples/talk-llama/models/models.h +0 -569
  902. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
  903. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  904. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
  905. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  906. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  907. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  908. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  909. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  910. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  911. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  912. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  913. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  914. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  915. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  916. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  917. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
  918. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  919. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
  920. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  921. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  922. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  923. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  924. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
  925. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
  926. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
  927. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
  928. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
  929. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  930. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  931. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
  932. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  933. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  934. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
  935. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  936. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  937. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  938. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  939. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  940. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  941. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  942. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  943. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  944. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  945. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  946. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  947. data/ext/sources/examples/talk-llama/speak +0 -40
  948. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  949. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  950. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  951. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  952. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  953. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
  954. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  955. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  956. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  957. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  958. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  959. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  960. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  961. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  962. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  963. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  964. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  965. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  966. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  967. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  968. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  969. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  970. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
  971. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
  972. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  973. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  974. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  975. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  976. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  977. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  978. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  979. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  980. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  981. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  982. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  983. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  984. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
  985. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  986. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
  987. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
  988. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  989. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  990. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  991. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
  992. data/ext/sources/tests/CMakeLists.txt +0 -112
  993. data/ext/sources/tests/earnings21/eval.mk +0 -58
  994. data/ext/sources/tests/earnings21/eval.py +0 -68
  995. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  996. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  997. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  998. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  999. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  1000. data/ext/sources/tests/en-0-ref.txt +0 -1
  1001. data/ext/sources/tests/en-1-ref.txt +0 -1
  1002. data/ext/sources/tests/en-2-ref.txt +0 -1
  1003. data/ext/sources/tests/es-0-ref.txt +0 -1
  1004. data/ext/sources/tests/librispeech/eval.mk +0 -39
  1005. data/ext/sources/tests/librispeech/eval.py +0 -47
  1006. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  1007. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  1008. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  1009. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  1010. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  1011. data/ext/sources/tests/run-tests.sh +0 -130
  1012. data/ext/sources/tests/test-c.c +0 -3
  1013. data/ext/sources/tests/test-vad-full.cpp +0 -56
  1014. data/ext/sources/tests/test-vad.cpp +0 -83
  1015. data/ext/sources/tests/test-whisper.js +0 -58
  1016. data/lib/whisper/context.rb +0 -15
  1017. data/lib/whisper/segment.rb +0 -58
@@ -56,6 +56,65 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
56
56
  }
57
57
  }
58
58
 
59
+ template <typename reorder_vec_dot_q_sycl, int ncols_dst>
60
+ static void mul_mat_vec_q_reorder_ncols(const void * __restrict__ vx, const void * __restrict__ vy,
61
+ float * __restrict__ dst, const int ncols, const int nrows,
62
+ const int stride_col_y_bytes, const int stride_col_dst,
63
+ const sycl::nd_item<3> & nd_item) {
64
+ using block_type = ggml_sycl_reordered::block_q_t<reorder_vec_dot_q_sycl::gtype>;
65
+ using block_traits = typename block_type::traits;
66
+
67
+ const auto sg = nd_item.get_sub_group();
68
+ const int sg_range = sg.get_group_linear_range();
69
+ const int workgroup_id = nd_item.get_group_linear_id();
70
+ const int sg_id = sg.get_group_linear_id();
71
+ const int row = workgroup_id * sg_range + sg_id;
72
+
73
+ if (row >= nrows) {
74
+ return;
75
+ }
76
+
77
+ const int blocks_per_row = ncols / block_traits::qk;
78
+ constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi);
79
+ constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq;
80
+ const int nblocks = nrows * (ncols / block_traits::qk);
81
+
82
+ static_assert(blocks_per_subgroup > 0);
83
+ static_assert(block_elements_per_subgroup > 0);
84
+
85
+ float partial_sum[ncols_dst] = {0.0f};
86
+ for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
87
+ const int ibx = row * blocks_per_row + i;
88
+
89
+ const auto bx_offset = block_type::get_block_offset(ibx, nblocks);
90
+ const auto d_offset = block_type::get_d_offset(nrows, ncols, ibx);
91
+ const int iby = i * block_type::block_to_q8_1_ratio();
92
+
93
+ #pragma unroll
94
+ for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) {
95
+ const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
96
+
97
+ #pragma unroll
98
+ for (int j = 0; j < ncols_dst; ++j) {
99
+ const char * vy_j = (const char *)vy + j * stride_col_y_bytes;
100
+ const int8_t * q8_1_quant_ptr = (const int8_t *)vy_j + iby * QK8_1;
101
+ const sycl::half2* q8_1_ds_ptr = (const sycl::half2 *)(vy_j + ncols + iby * sizeof(sycl::half2));
102
+
103
+ partial_sum[j] += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs);
104
+ }
105
+ }
106
+ }
107
+
108
+ #pragma unroll
109
+ for (int j = 0; j < ncols_dst; ++j) {
110
+ float sum = sycl::reduce_over_group(nd_item.get_sub_group(), partial_sum[j], std::plus<>());
111
+
112
+ if (sg.leader()) {
113
+ dst[j * stride_col_dst + row] = sum;
114
+ }
115
+ }
116
+ }
117
+
59
118
  template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
60
119
  static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
61
120
  const int ncols, const int nrows, const sycl::nd_item<3> & item_ct1) {
@@ -100,6 +159,70 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
100
159
  }
101
160
  }
102
161
 
162
+ template <int qk, int qi, typename block_q_t, int vdr,
163
+ vec_dot_q_sycl_t vec_dot_q_sycl, int ncols_dst>
164
+ static void mul_mat_vec_q_ncols(
165
+ const void * __restrict__ vx,
166
+ const void * __restrict__ vy,
167
+ float * __restrict__ dst,
168
+ const int ncols,
169
+ const int nrows,
170
+ const int stride_col_y,
171
+ const int stride_col_dst,
172
+ const sycl::nd_item<3> & item_ct1) {
173
+
174
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1)
175
+ + item_ct1.get_local_id(1);
176
+
177
+ if (row >= nrows) {
178
+ return;
179
+ }
180
+
181
+ const int blocks_per_row = ncols / qk;
182
+ constexpr int blocks_per_warp = (vdr * WARP_SIZE + qi - 1) / qi;
183
+
184
+ // partial sums: one per output column
185
+ float tmp[ncols_dst] = {0.0f};
186
+
187
+ const block_q_t * x = (const block_q_t *) vx;
188
+ const block_q8_1 * y = (const block_q8_1 *) vy;
189
+
190
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr);
191
+ i < blocks_per_row;
192
+ i += blocks_per_warp) {
193
+
194
+ const int ibx = row * blocks_per_row + i;
195
+ const int iby = i * (qk / QK8_1);
196
+
197
+ // read weight block once, dot against all columns
198
+ for (size_t elem = 0; elem < qi / vdr; elem += WARP_SIZE) {
199
+ const int iqs = elem + vdr * (item_ct1.get_local_id(2) % (qi / vdr));
200
+
201
+ #pragma unroll
202
+ for (int j = 0; j < ncols_dst; ++j) {
203
+ tmp[j] += vec_dot_q_sycl(&x[ibx], &y[j * stride_col_y + iby], iqs);
204
+ }
205
+ }
206
+ }
207
+
208
+ // reduce within subgroup
209
+ #pragma unroll
210
+ for (int j = 0; j < ncols_dst; ++j) {
211
+ #pragma unroll
212
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
213
+ tmp[j] += dpct::permute_sub_group_by_xor(
214
+ item_ct1.get_sub_group(), tmp[j], mask);
215
+ }
216
+ }
217
+
218
+ if (item_ct1.get_local_id(2) == 0) {
219
+ #pragma unroll
220
+ for (int j = 0; j < ncols_dst; ++j) {
221
+ dst[j * stride_col_dst + row] = tmp[j];
222
+ }
223
+ }
224
+ }
225
+
103
226
  template <int qk, int qi, typename block_q_t, int vdr>
104
227
  static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
105
228
  const void *__restrict__ vy,
@@ -537,9 +660,9 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
537
660
  static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
538
661
  const int nrows, dpct::queue_ptr stream) {
539
662
  GGML_ASSERT(ncols % QK4_0 == 0);
540
- const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
541
- constexpr size_t num_subgroups = 16;
542
- GGML_ASSERT(block_num_y % num_subgroups == 0);
663
+ // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel.
664
+ constexpr size_t num_subgroups = WARP_SIZE;
665
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups;
543
666
 
544
667
  const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
545
668
  const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
@@ -553,6 +676,45 @@ static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy,
553
676
  });
554
677
  }
555
678
 
679
+ template <int ncols_dst>
680
+ static void reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols(
681
+ const void * vx, const void * vy, float * dst,
682
+ const int ncols, const int nrows,
683
+ const int stride_col_y_bytes, const int stride_col_dst,
684
+ dpct::queue_ptr stream) {
685
+ GGML_ASSERT(ncols % QK4_0 == 0);
686
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
687
+ constexpr size_t num_subgroups = 16;
688
+ GGML_ASSERT(block_num_y % num_subgroups == 0);
689
+ const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
690
+ const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
691
+ stream->submit([&](sycl::handler & cgh) {
692
+ cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
693
+ [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
694
+ mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>, ncols_dst>(
695
+ vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item);
696
+ });
697
+ });
698
+ }
699
+
700
+ static void reorder_mul_mat_vec_q4_0_q8_1_sycl_switch_ncols(
701
+ const void * vx, const void * vy, float * dst,
702
+ const int ncols, const int nrows, const int ncols_dst,
703
+ const int stride_col_y_bytes, const int stride_col_dst,
704
+ dpct::queue_ptr stream) {
705
+ switch (ncols_dst) {
706
+ case 1: reorder_mul_mat_vec_q4_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
707
+ case 2: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
708
+ case 3: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
709
+ case 4: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
710
+ case 5: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
711
+ case 6: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
712
+ case 7: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
713
+ case 8: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
714
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q4_0 reorder multi-col MMVQ", ncols_dst);
715
+ }
716
+ }
717
+
556
718
  static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
557
719
  dpct::queue_ptr stream) {
558
720
  GGML_ASSERT(ncols % QK4_0 == 0);
@@ -571,6 +733,45 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float *
571
733
  }
572
734
  }
573
735
 
736
+ template <int ncols_dst>
737
+ static void mul_mat_vec_q4_0_q8_1_sycl_ncols(
738
+ const void * vx, const void * vy, float * dst,
739
+ const int ncols, const int nrows,
740
+ const int stride_col_y, const int stride_col_dst,
741
+ dpct::queue_ptr stream) {
742
+ GGML_ASSERT(ncols % QK4_0 == 0);
743
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
744
+ const sycl::range<3> block_nums(1, 1, block_num_y);
745
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
746
+ stream->submit([&](sycl::handler & cgh) {
747
+ cgh.parallel_for(
748
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
749
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
750
+ mul_mat_vec_q_ncols<QK4_0, QI4_0, block_q4_0,
751
+ VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1, ncols_dst>(
752
+ vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
753
+ });
754
+ });
755
+ }
756
+
757
+ static void mul_mat_vec_q4_0_q8_1_sycl_switch_ncols(
758
+ const void * vx, const void * vy, float * dst,
759
+ const int ncols, const int nrows, const int ncols_dst,
760
+ const int stride_col_y, const int stride_col_dst,
761
+ dpct::queue_ptr stream) {
762
+ switch (ncols_dst) {
763
+ case 1: mul_mat_vec_q4_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
764
+ case 2: mul_mat_vec_q4_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
765
+ case 3: mul_mat_vec_q4_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
766
+ case 4: mul_mat_vec_q4_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
767
+ case 5: mul_mat_vec_q4_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
768
+ case 6: mul_mat_vec_q4_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
769
+ case 7: mul_mat_vec_q4_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
770
+ case 8: mul_mat_vec_q4_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
771
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q4_0 multi-col MMVQ", ncols_dst);
772
+ }
773
+ }
774
+
574
775
  static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
575
776
  float *dst, const int ncols,
576
777
  const int nrows,
@@ -595,6 +796,45 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
595
796
  }
596
797
  }
597
798
 
799
+ template <int ncols_dst>
800
+ static void mul_mat_vec_q4_1_q8_1_sycl_ncols(
801
+ const void * vx, const void * vy, float * dst,
802
+ const int ncols, const int nrows,
803
+ const int stride_col_y, const int stride_col_dst,
804
+ dpct::queue_ptr stream) {
805
+ GGML_ASSERT(ncols % QK4_1 == 0);
806
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
807
+ const sycl::range<3> block_nums(1, 1, block_num_y);
808
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
809
+ stream->submit([&](sycl::handler & cgh) {
810
+ cgh.parallel_for(
811
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
812
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
813
+ mul_mat_vec_q_ncols<QK4_0, QI4_1, block_q4_1,
814
+ VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1, ncols_dst>(
815
+ vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
816
+ });
817
+ });
818
+ }
819
+
820
+ static void mul_mat_vec_q4_1_q8_1_sycl_switch_ncols(
821
+ const void * vx, const void * vy, float * dst,
822
+ const int ncols, const int nrows, const int ncols_dst,
823
+ const int stride_col_y, const int stride_col_dst,
824
+ dpct::queue_ptr stream) {
825
+ switch (ncols_dst) {
826
+ case 1: mul_mat_vec_q4_1_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
827
+ case 2: mul_mat_vec_q4_1_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
828
+ case 3: mul_mat_vec_q4_1_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
829
+ case 4: mul_mat_vec_q4_1_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
830
+ case 5: mul_mat_vec_q4_1_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
831
+ case 6: mul_mat_vec_q4_1_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
832
+ case 7: mul_mat_vec_q4_1_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
833
+ case 8: mul_mat_vec_q4_1_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
834
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q4_1 multi-col MMVQ", ncols_dst);
835
+ }
836
+ }
837
+
598
838
  static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
599
839
  dpct::queue_ptr stream) {
600
840
  GGML_ASSERT(ncols % QK_MXFP4 == 0);
@@ -613,6 +853,101 @@ static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float
613
853
  }
614
854
  }
615
855
 
856
+ template <int ncols_dst>
857
+ static void mul_mat_vec_mxfp4_q8_1_sycl_ncols(
858
+ const void * vx, const void * vy, float * dst,
859
+ const int ncols, const int nrows,
860
+ const int stride_col_y, const int stride_col_dst,
861
+ dpct::queue_ptr stream) {
862
+ GGML_ASSERT(ncols % QK_MXFP4 == 0);
863
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
864
+ const sycl::range<3> block_nums(1, 1, block_num_y);
865
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
866
+ stream->submit([&](sycl::handler & cgh) {
867
+ cgh.parallel_for(
868
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
869
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
870
+ mul_mat_vec_q_ncols<QK_MXFP4, QI_MXFP4, block_mxfp4,
871
+ VDR_MXFP4_Q8_1_MMVQ, vec_dot_mxfp4_q8_1, ncols_dst>(
872
+ vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
873
+ });
874
+ });
875
+ }
876
+
877
+ static void mul_mat_vec_mxfp4_q8_1_sycl_switch_ncols(
878
+ const void * vx, const void * vy, float * dst,
879
+ const int ncols, const int nrows, const int ncols_dst,
880
+ const int stride_col_y, const int stride_col_dst,
881
+ dpct::queue_ptr stream) {
882
+ switch (ncols_dst) {
883
+ case 1: mul_mat_vec_mxfp4_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
884
+ case 2: mul_mat_vec_mxfp4_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
885
+ case 3: mul_mat_vec_mxfp4_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
886
+ case 4: mul_mat_vec_mxfp4_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
887
+ case 5: mul_mat_vec_mxfp4_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
888
+ case 6: mul_mat_vec_mxfp4_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
889
+ case 7: mul_mat_vec_mxfp4_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
890
+ case 8: mul_mat_vec_mxfp4_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
891
+ default: GGML_ABORT("unsupported ncols_dst=%d for MXFP4 multi-col MMVQ", ncols_dst);
892
+ }
893
+ }
894
+
895
+ static void mul_mat_vec_nvfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
896
+ dpct::queue_ptr stream) {
897
+ GGML_ASSERT(ncols % QK_NVFP4 == 0);
898
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
899
+ const sycl::range<3> block_nums(1, 1, block_num_y);
900
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
901
+
902
+ {
903
+ stream->submit([&](sycl::handler & cgh) {
904
+ cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
905
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
906
+ mul_mat_vec_q<QK_NVFP4, QI_NVFP4, block_nvfp4, VDR_NVFP4_Q8_1_MMVQ, vec_dot_nvfp4_q8_1>(
907
+ vx, vy, dst, ncols, nrows, item_ct1);
908
+ });
909
+ });
910
+ }
911
+ }
912
+
913
+ template <int ncols_dst>
914
+ static void mul_mat_vec_nvfp4_q8_1_sycl_ncols(
915
+ const void * vx, const void * vy, float * dst,
916
+ const int ncols, const int nrows,
917
+ const int stride_col_y, const int stride_col_dst,
918
+ dpct::queue_ptr stream) {
919
+ GGML_ASSERT(ncols % QK_NVFP4 == 0);
920
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
921
+ const sycl::range<3> block_nums(1, 1, block_num_y);
922
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
923
+ stream->submit([&](sycl::handler & cgh) {
924
+ cgh.parallel_for(
925
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
926
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
927
+ mul_mat_vec_q_ncols<QK_NVFP4, QI_NVFP4, block_nvfp4,
928
+ VDR_NVFP4_Q8_1_MMVQ, vec_dot_nvfp4_q8_1, ncols_dst>(
929
+ vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
930
+ });
931
+ });
932
+ }
933
+
934
+ static void mul_mat_vec_nvfp4_q8_1_sycl_switch_ncols(
935
+ const void * vx, const void * vy, float * dst,
936
+ const int ncols, const int nrows, const int ncols_dst,
937
+ const int stride_col_y, const int stride_col_dst,
938
+ dpct::queue_ptr stream) {
939
+ switch (ncols_dst) {
940
+ case 1: mul_mat_vec_nvfp4_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
941
+ case 2: mul_mat_vec_nvfp4_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
942
+ case 3: mul_mat_vec_nvfp4_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
943
+ case 4: mul_mat_vec_nvfp4_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
944
+ case 5: mul_mat_vec_nvfp4_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
945
+ case 6: mul_mat_vec_nvfp4_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
946
+ case 7: mul_mat_vec_nvfp4_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
947
+ case 8: mul_mat_vec_nvfp4_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
948
+ default: GGML_ABORT("unsupported ncols_dst=%d for NVFP4 multi-col MMVQ", ncols_dst);
949
+ }
950
+ }
616
951
 
617
952
  static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
618
953
  float *dst, const int ncols,
@@ -638,6 +973,45 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
638
973
  }
639
974
  }
640
975
 
976
+ template <int ncols_dst>
977
+ static void mul_mat_vec_q5_0_q8_1_sycl_ncols(
978
+ const void * vx, const void * vy, float * dst,
979
+ const int ncols, const int nrows,
980
+ const int stride_col_y, const int stride_col_dst,
981
+ dpct::queue_ptr stream) {
982
+ GGML_ASSERT(ncols % QK5_0 == 0);
983
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
984
+ const sycl::range<3> block_nums(1, 1, block_num_y);
985
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
986
+ stream->submit([&](sycl::handler & cgh) {
987
+ cgh.parallel_for(
988
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
989
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
990
+ mul_mat_vec_q_ncols<QK5_0, QI5_0, block_q5_0,
991
+ VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1, ncols_dst>(
992
+ vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
993
+ });
994
+ });
995
+ }
996
+
997
+ static void mul_mat_vec_q5_0_q8_1_sycl_switch_ncols(
998
+ const void * vx, const void * vy, float * dst,
999
+ const int ncols, const int nrows, const int ncols_dst,
1000
+ const int stride_col_y, const int stride_col_dst,
1001
+ dpct::queue_ptr stream) {
1002
+ switch (ncols_dst) {
1003
+ case 1: mul_mat_vec_q5_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1004
+ case 2: mul_mat_vec_q5_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1005
+ case 3: mul_mat_vec_q5_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1006
+ case 4: mul_mat_vec_q5_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1007
+ case 5: mul_mat_vec_q5_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1008
+ case 6: mul_mat_vec_q5_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1009
+ case 7: mul_mat_vec_q5_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1010
+ case 8: mul_mat_vec_q5_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1011
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q5_0 multi-col MMVQ", ncols_dst);
1012
+ }
1013
+ }
1014
+
641
1015
  static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
642
1016
  float *dst, const int ncols,
643
1017
  const int nrows,
@@ -662,6 +1036,103 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
662
1036
  }
663
1037
  }
664
1038
 
1039
+ template <int ncols_dst>
1040
+ static void mul_mat_vec_q5_1_q8_1_sycl_ncols(
1041
+ const void * vx, const void * vy, float * dst,
1042
+ const int ncols, const int nrows,
1043
+ const int stride_col_y, const int stride_col_dst,
1044
+ dpct::queue_ptr stream) {
1045
+ GGML_ASSERT(ncols % QK5_1 == 0);
1046
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
1047
+ const sycl::range<3> block_nums(1, 1, block_num_y);
1048
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
1049
+ stream->submit([&](sycl::handler & cgh) {
1050
+ cgh.parallel_for(
1051
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
1052
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1053
+ mul_mat_vec_q_ncols<QK5_1, QI5_1, block_q5_1,
1054
+ VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1, ncols_dst>(
1055
+ vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
1056
+ });
1057
+ });
1058
+ }
1059
+
1060
+ static void mul_mat_vec_q5_1_q8_1_sycl_switch_ncols(
1061
+ const void * vx, const void * vy, float * dst,
1062
+ const int ncols, const int nrows, const int ncols_dst,
1063
+ const int stride_col_y, const int stride_col_dst,
1064
+ dpct::queue_ptr stream) {
1065
+ switch (ncols_dst) {
1066
+ case 1: mul_mat_vec_q5_1_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1067
+ case 2: mul_mat_vec_q5_1_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1068
+ case 3: mul_mat_vec_q5_1_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1069
+ case 4: mul_mat_vec_q5_1_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1070
+ case 5: mul_mat_vec_q5_1_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1071
+ case 6: mul_mat_vec_q5_1_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1072
+ case 7: mul_mat_vec_q5_1_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1073
+ case 8: mul_mat_vec_q5_1_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1074
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q5_1 multi-col MMVQ", ncols_dst);
1075
+ }
1076
+ }
1077
+
1078
+ static void reorder_mul_mat_vec_q8_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
1079
+ const int nrows, dpct::queue_ptr stream) {
1080
+ GGML_ASSERT(ncols % QK8_0 == 0);
1081
+ // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel.
1082
+ constexpr size_t num_subgroups = WARP_SIZE;
1083
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups;
1084
+
1085
+ const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
1086
+ const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1087
+
1088
+ stream->submit([&](sycl::handler & cgh) {
1089
+ cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1090
+ [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1091
+ mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q8_0>>(vx, vy, dst, ncols, nrows,
1092
+ nd_item);
1093
+ });
1094
+ });
1095
+ }
1096
+
1097
+ template <int ncols_dst>
1098
+ static void reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols(
1099
+ const void * vx, const void * vy, float * dst,
1100
+ const int ncols, const int nrows,
1101
+ const int stride_col_y_bytes, const int stride_col_dst,
1102
+ dpct::queue_ptr stream) {
1103
+ GGML_ASSERT(ncols % QK8_0 == 0);
1104
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
1105
+ constexpr size_t num_subgroups = 16;
1106
+ GGML_ASSERT(block_num_y % num_subgroups == 0);
1107
+ const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
1108
+ const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1109
+ stream->submit([&](sycl::handler & cgh) {
1110
+ cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1111
+ [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1112
+ mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q8_0>, ncols_dst>(
1113
+ vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item);
1114
+ });
1115
+ });
1116
+ }
1117
+
1118
+ static void reorder_mul_mat_vec_q8_0_q8_1_sycl_switch_ncols(
1119
+ const void * vx, const void * vy, float * dst,
1120
+ const int ncols, const int nrows, const int ncols_dst,
1121
+ const int stride_col_y_bytes, const int stride_col_dst,
1122
+ dpct::queue_ptr stream) {
1123
+ switch (ncols_dst) {
1124
+ case 1: reorder_mul_mat_vec_q8_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1125
+ case 2: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1126
+ case 3: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1127
+ case 4: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1128
+ case 5: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1129
+ case 6: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1130
+ case 7: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1131
+ case 8: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1132
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q8_0 reorder multi-col MMVQ", ncols_dst);
1133
+ }
1134
+ }
1135
+
665
1136
  static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
666
1137
  float *dst, const int ncols,
667
1138
  const int nrows,
@@ -686,6 +1157,45 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
686
1157
  }
687
1158
  }
688
1159
 
1160
+ template <int ncols_dst>
1161
+ static void mul_mat_vec_q8_0_q8_1_sycl_ncols(
1162
+ const void * vx, const void * vy, float * dst,
1163
+ const int ncols, const int nrows,
1164
+ const int stride_col_y, const int stride_col_dst,
1165
+ dpct::queue_ptr stream) {
1166
+ GGML_ASSERT(ncols % QK8_0 == 0);
1167
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
1168
+ const sycl::range<3> block_nums(1, 1, block_num_y);
1169
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
1170
+ stream->submit([&](sycl::handler & cgh) {
1171
+ cgh.parallel_for(
1172
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
1173
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1174
+ mul_mat_vec_q_ncols<QK8_0, QI8_0, block_q8_0,
1175
+ VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1, ncols_dst>(
1176
+ vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
1177
+ });
1178
+ });
1179
+ }
1180
+
1181
+ static void mul_mat_vec_q8_0_q8_1_sycl_switch_ncols(
1182
+ const void * vx, const void * vy, float * dst,
1183
+ const int ncols, const int nrows, const int ncols_dst,
1184
+ const int stride_col_y, const int stride_col_dst,
1185
+ dpct::queue_ptr stream) {
1186
+ switch (ncols_dst) {
1187
+ case 1: mul_mat_vec_q8_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1188
+ case 2: mul_mat_vec_q8_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1189
+ case 3: mul_mat_vec_q8_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1190
+ case 4: mul_mat_vec_q8_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1191
+ case 5: mul_mat_vec_q8_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1192
+ case 6: mul_mat_vec_q8_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1193
+ case 7: mul_mat_vec_q8_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1194
+ case 8: mul_mat_vec_q8_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1195
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q8_0 multi-col MMVQ", ncols_dst);
1196
+ }
1197
+ }
1198
+
689
1199
  static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
690
1200
  float *dst, const int ncols,
691
1201
  const int nrows,
@@ -710,6 +1220,45 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
710
1220
  }
711
1221
  }
712
1222
 
1223
+ template <int ncols_dst>
1224
+ static void mul_mat_vec_q2_K_q8_1_sycl_ncols(
1225
+ const void * vx, const void * vy, float * dst,
1226
+ const int ncols, const int nrows,
1227
+ const int stride_col_y, const int stride_col_dst,
1228
+ dpct::queue_ptr stream) {
1229
+ GGML_ASSERT(ncols % QK_K == 0);
1230
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
1231
+ const sycl::range<3> block_nums(1, 1, block_num_y);
1232
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
1233
+ stream->submit([&](sycl::handler & cgh) {
1234
+ cgh.parallel_for(
1235
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
1236
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1237
+ mul_mat_vec_q_ncols<QK_K, QI2_K, block_q2_K,
1238
+ VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1, ncols_dst>(
1239
+ vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
1240
+ });
1241
+ });
1242
+ }
1243
+
1244
+ static void mul_mat_vec_q2_K_q8_1_sycl_switch_ncols(
1245
+ const void * vx, const void * vy, float * dst,
1246
+ const int ncols, const int nrows, const int ncols_dst,
1247
+ const int stride_col_y, const int stride_col_dst,
1248
+ dpct::queue_ptr stream) {
1249
+ switch (ncols_dst) {
1250
+ case 1: mul_mat_vec_q2_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1251
+ case 2: mul_mat_vec_q2_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1252
+ case 3: mul_mat_vec_q2_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1253
+ case 4: mul_mat_vec_q2_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1254
+ case 5: mul_mat_vec_q2_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1255
+ case 6: mul_mat_vec_q2_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1256
+ case 7: mul_mat_vec_q2_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1257
+ case 8: mul_mat_vec_q2_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1258
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q2_K multi-col MMVQ", ncols_dst);
1259
+ }
1260
+ }
1261
+
713
1262
  static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
714
1263
  float *dst, const int ncols,
715
1264
  const int nrows,
@@ -734,6 +1283,105 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
734
1283
  }
735
1284
  }
736
1285
 
1286
+ static void reorder_mul_mat_vec_q3_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
1287
+ const int nrows, dpct::queue_ptr stream) {
1288
+ GGML_ASSERT(ncols % QK_K == 0);
1289
+
1290
+ // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel.
1291
+ constexpr size_t num_subgroups = WARP_SIZE;
1292
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups;
1293
+
1294
+ const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
1295
+ const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1296
+
1297
+ stream->submit([&](sycl::handler & cgh) {
1298
+ cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1299
+ [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1300
+ mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q3_K>>(vx, vy, dst, ncols, nrows,
1301
+ nd_item);
1302
+ });
1303
+ });
1304
+ }
1305
+
1306
+ template <int ncols_dst>
1307
+ static void reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols(
1308
+ const void * vx, const void * vy, float * dst,
1309
+ const int ncols, const int nrows,
1310
+ const int stride_col_y_bytes, const int stride_col_dst,
1311
+ dpct::queue_ptr stream) {
1312
+ GGML_ASSERT(ncols % QK_K == 0);
1313
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
1314
+ constexpr size_t num_subgroups = 16;
1315
+ GGML_ASSERT(block_num_y % num_subgroups == 0);
1316
+ const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
1317
+ const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1318
+ stream->submit([&](sycl::handler & cgh) {
1319
+ cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1320
+ [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1321
+ mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q3_K>, ncols_dst>(
1322
+ vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item);
1323
+ });
1324
+ });
1325
+ }
1326
+
1327
+ static void reorder_mul_mat_vec_q3_k_q8_1_sycl_switch_ncols(
1328
+ const void * vx, const void * vy, float * dst,
1329
+ const int ncols, const int nrows, const int ncols_dst,
1330
+ const int stride_col_y_bytes, const int stride_col_dst,
1331
+ dpct::queue_ptr stream) {
1332
+ switch (ncols_dst) {
1333
+ case 1: reorder_mul_mat_vec_q3_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1334
+ case 2: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1335
+ case 3: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1336
+ case 4: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1337
+ case 5: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1338
+ case 6: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1339
+ case 7: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1340
+ case 8: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1341
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q3_K reorder multi-col MMVQ", ncols_dst);
1342
+ }
1343
+ }
1344
+
1345
+ template <int ncols_dst>
1346
+ static void mul_mat_vec_q3_K_q8_1_sycl_ncols(
1347
+ const void * vx, const void * vy, float * dst,
1348
+ const int ncols, const int nrows,
1349
+ const int stride_col_y, const int stride_col_dst,
1350
+ dpct::queue_ptr stream) {
1351
+ GGML_ASSERT(ncols % QK_K == 0);
1352
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
1353
+ const sycl::range<3> block_nums(1, 1, block_num_y);
1354
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
1355
+ stream->submit([&](sycl::handler & cgh) {
1356
+ cgh.parallel_for(
1357
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
1358
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1359
+ mul_mat_vec_q_ncols<QK_K, QI3_K, block_q3_K,
1360
+ VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1, ncols_dst>(
1361
+ vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
1362
+ });
1363
+ });
1364
+ }
1365
+
1366
+ static void mul_mat_vec_q3_K_q8_1_sycl_switch_ncols(
1367
+ const void * vx, const void * vy, float * dst,
1368
+ const int ncols, const int nrows, const int ncols_dst,
1369
+ const int stride_col_y, const int stride_col_dst,
1370
+ dpct::queue_ptr stream) {
1371
+ switch (ncols_dst) {
1372
+ case 1: mul_mat_vec_q3_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1373
+ case 2: mul_mat_vec_q3_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1374
+ case 3: mul_mat_vec_q3_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1375
+ case 4: mul_mat_vec_q3_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1376
+ case 5: mul_mat_vec_q3_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1377
+ case 6: mul_mat_vec_q3_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1378
+ case 7: mul_mat_vec_q3_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1379
+ case 8: mul_mat_vec_q3_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1380
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q3_K multi-col MMVQ", ncols_dst);
1381
+ }
1382
+ }
1383
+
1384
+
737
1385
  static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
738
1386
  float *dst, const int ncols,
739
1387
  const int nrows,
@@ -758,13 +1406,58 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
758
1406
  }
759
1407
  }
760
1408
 
1409
+ template <int ncols_dst>
1410
+ static void mul_mat_vec_q4_K_q8_1_sycl_ncols(
1411
+ const void * vx, const void * vy, float * dst,
1412
+ const int ncols, const int nrows,
1413
+ const int stride_col_y, const int stride_col_dst,
1414
+ dpct::queue_ptr stream) {
1415
+ GGML_ASSERT(ncols % QK_K == 0);
1416
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
1417
+ const sycl::range<3> block_nums(1, 1, block_num_y);
1418
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
1419
+
1420
+ stream->submit([&](sycl::handler & cgh) {
1421
+ cgh.parallel_for(
1422
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
1423
+ [=](sycl::nd_item<3> item_ct1)
1424
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1425
+ mul_mat_vec_q_ncols<QK_K, QI4_K, block_q4_K,
1426
+ VDR_Q4_K_Q8_1_MMVQ,
1427
+ vec_dot_q4_K_q8_1,
1428
+ ncols_dst>(
1429
+ vx, vy, dst, ncols, nrows,
1430
+ stride_col_y, stride_col_dst, item_ct1);
1431
+ });
1432
+ });
1433
+ }
1434
+
1435
+ static void mul_mat_vec_q4_K_q8_1_sycl_switch_ncols(
1436
+ const void * vx, const void * vy, float * dst,
1437
+ const int ncols, const int nrows,
1438
+ const int ncols_dst,
1439
+ const int stride_col_y, const int stride_col_dst,
1440
+ dpct::queue_ptr stream) {
1441
+ switch (ncols_dst) {
1442
+ case 1: mul_mat_vec_q4_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1443
+ case 2: mul_mat_vec_q4_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1444
+ case 3: mul_mat_vec_q4_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1445
+ case 4: mul_mat_vec_q4_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1446
+ case 5: mul_mat_vec_q4_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1447
+ case 6: mul_mat_vec_q4_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1448
+ case 7: mul_mat_vec_q4_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1449
+ case 8: mul_mat_vec_q4_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1450
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q4_K multi-col MMVQ", ncols_dst);
1451
+ }
1452
+ }
1453
+
761
1454
  static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
762
1455
  const int nrows, dpct::queue_ptr stream) {
763
1456
  GGML_ASSERT(ncols % QK_K == 0);
764
1457
 
765
- const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
766
- constexpr size_t num_subgroups = 16;
767
- GGML_ASSERT(block_num_y % num_subgroups == 0);
1458
+ // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel.
1459
+ constexpr size_t num_subgroups = WARP_SIZE;
1460
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups;
768
1461
 
769
1462
  const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
770
1463
  const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
@@ -778,6 +1471,44 @@ static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy,
778
1471
  });
779
1472
  }
780
1473
 
1474
+ template <int ncols_dst>
1475
+ static void reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols(
1476
+ const void * vx, const void * vy, float * dst,
1477
+ const int ncols, const int nrows,
1478
+ const int stride_col_y_bytes, const int stride_col_dst,
1479
+ dpct::queue_ptr stream) {
1480
+ GGML_ASSERT(ncols % QK_K == 0);
1481
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
1482
+ constexpr size_t num_subgroups = 16;
1483
+ GGML_ASSERT(block_num_y % num_subgroups == 0);
1484
+ const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
1485
+ const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1486
+ stream->submit([&](sycl::handler & cgh) {
1487
+ cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1488
+ [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1489
+ mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>, ncols_dst>(
1490
+ vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item);
1491
+ });
1492
+ });
1493
+ }
1494
+
1495
+ static void reorder_mul_mat_vec_q4_k_q8_1_sycl_switch_ncols(
1496
+ const void * vx, const void * vy, float * dst,
1497
+ const int ncols, const int nrows, const int ncols_dst,
1498
+ const int stride_col_y_bytes, const int stride_col_dst,
1499
+ dpct::queue_ptr stream) {
1500
+ switch (ncols_dst) {
1501
+ case 1: reorder_mul_mat_vec_q4_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1502
+ case 2: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1503
+ case 3: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1504
+ case 4: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1505
+ case 5: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1506
+ case 6: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1507
+ case 7: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1508
+ case 8: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1509
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q4_K reorder multi-col MMVQ", ncols_dst);
1510
+ }
1511
+ }
781
1512
 
782
1513
  static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
783
1514
  float *dst, const int ncols,
@@ -803,9 +1534,55 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
803
1534
  }
804
1535
  }
805
1536
 
806
- static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
1537
+ template <int ncols_dst>
1538
+ static void mul_mat_vec_q5_K_q8_1_sycl_ncols(
1539
+ const void * vx, const void * vy, float * dst,
1540
+ const int ncols, const int nrows,
1541
+ const int stride_col_y, const int stride_col_dst,
1542
+ dpct::queue_ptr stream) {
1543
+ GGML_ASSERT(ncols % QK_K == 0);
1544
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
1545
+ const sycl::range<3> block_nums(1, 1, block_num_y);
1546
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
1547
+
1548
+ stream->submit([&](sycl::handler & cgh) {
1549
+ cgh.parallel_for(
1550
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
1551
+ [=](sycl::nd_item<3> item_ct1)
1552
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1553
+ mul_mat_vec_q_ncols<QK_K, QI5_K, block_q5_K,
1554
+ VDR_Q5_K_Q8_1_MMVQ,
1555
+ vec_dot_q5_K_q8_1,
1556
+ ncols_dst>(
1557
+ vx, vy, dst, ncols, nrows,
1558
+ stride_col_y, stride_col_dst, item_ct1);
1559
+ });
1560
+ });
1561
+ }
1562
+
1563
+ static void mul_mat_vec_q5_K_q8_1_sycl_switch_ncols(
1564
+ const void * vx, const void * vy, float * dst,
1565
+ const int ncols, const int nrows,
1566
+ const int ncols_dst,
1567
+ const int stride_col_y, const int stride_col_dst,
1568
+ dpct::queue_ptr stream) {
1569
+ switch (ncols_dst) {
1570
+ case 1: mul_mat_vec_q5_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1571
+ case 2: mul_mat_vec_q5_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1572
+ case 3: mul_mat_vec_q5_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1573
+ case 4: mul_mat_vec_q5_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1574
+ case 5: mul_mat_vec_q5_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1575
+ case 6: mul_mat_vec_q5_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1576
+ case 7: mul_mat_vec_q5_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1577
+ case 8: mul_mat_vec_q5_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1578
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q5_K multi-col MMVQ", ncols_dst);
1579
+ }
1580
+ }
1581
+
1582
+ static void reorder_mul_mat_vec_q5_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
807
1583
  const int nrows, dpct::queue_ptr stream) {
808
1584
  GGML_ASSERT(ncols % QK_K == 0);
1585
+
809
1586
  const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
810
1587
  constexpr size_t num_subgroups = 16;
811
1588
  GGML_ASSERT(block_num_y % num_subgroups == 0);
@@ -813,6 +1590,64 @@ static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy,
813
1590
  const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
814
1591
  const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
815
1592
 
1593
+ stream->submit([&](sycl::handler & cgh) {
1594
+ cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1595
+ [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1596
+ mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q5_K>>(vx, vy, dst, ncols,
1597
+ nrows, nd_item);
1598
+ });
1599
+ });
1600
+ }
1601
+
1602
+ template <int ncols_dst>
1603
+ static void reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols(
1604
+ const void * vx, const void * vy, float * dst,
1605
+ const int ncols, const int nrows,
1606
+ const int stride_col_y_bytes, const int stride_col_dst,
1607
+ dpct::queue_ptr stream) {
1608
+ GGML_ASSERT(ncols % QK_K == 0);
1609
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
1610
+ constexpr size_t num_subgroups = 16;
1611
+ GGML_ASSERT(block_num_y % num_subgroups == 0);
1612
+ const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
1613
+ const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1614
+ stream->submit([&](sycl::handler & cgh) {
1615
+ cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1616
+ [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1617
+ mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q5_K>, ncols_dst>(
1618
+ vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item);
1619
+ });
1620
+ });
1621
+ }
1622
+
1623
+ static void reorder_mul_mat_vec_q5_k_q8_1_sycl_switch_ncols(
1624
+ const void * vx, const void * vy, float * dst,
1625
+ const int ncols, const int nrows, const int ncols_dst,
1626
+ const int stride_col_y_bytes, const int stride_col_dst,
1627
+ dpct::queue_ptr stream) {
1628
+ switch (ncols_dst) {
1629
+ case 1: reorder_mul_mat_vec_q5_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1630
+ case 2: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1631
+ case 3: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1632
+ case 4: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1633
+ case 5: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1634
+ case 6: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1635
+ case 7: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1636
+ case 8: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1637
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q5_K reorder multi-col MMVQ", ncols_dst);
1638
+ }
1639
+ }
1640
+
1641
+ static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
1642
+ const int nrows, dpct::queue_ptr stream) {
1643
+ GGML_ASSERT(ncols % QK_K == 0);
1644
+ // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel.
1645
+ constexpr size_t num_subgroups = WARP_SIZE;
1646
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups;
1647
+
1648
+ const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
1649
+ const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1650
+
816
1651
  stream->submit([&](sycl::handler & cgh) {
817
1652
  cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
818
1653
  [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
@@ -821,6 +1656,46 @@ static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy,
821
1656
  });
822
1657
  });
823
1658
  }
1659
+
1660
+ template <int ncols_dst>
1661
+ static void reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols(
1662
+ const void * vx, const void * vy, float * dst,
1663
+ const int ncols, const int nrows,
1664
+ const int stride_col_y_bytes, const int stride_col_dst,
1665
+ dpct::queue_ptr stream) {
1666
+ GGML_ASSERT(ncols % QK_K == 0);
1667
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
1668
+ constexpr size_t num_subgroups = 16;
1669
+ GGML_ASSERT(block_num_y % num_subgroups == 0);
1670
+ const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
1671
+ const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
1672
+ stream->submit([&](sycl::handler & cgh) {
1673
+ cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
1674
+ [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1675
+ mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>, ncols_dst>(
1676
+ vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item);
1677
+ });
1678
+ });
1679
+ }
1680
+
1681
+ static void reorder_mul_mat_vec_q6_k_q8_1_sycl_switch_ncols(
1682
+ const void * vx, const void * vy, float * dst,
1683
+ const int ncols, const int nrows, const int ncols_dst,
1684
+ const int stride_col_y_bytes, const int stride_col_dst,
1685
+ dpct::queue_ptr stream) {
1686
+ switch (ncols_dst) {
1687
+ case 1: reorder_mul_mat_vec_q6_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1688
+ case 2: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1689
+ case 3: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1690
+ case 4: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1691
+ case 5: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1692
+ case 6: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1693
+ case 7: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1694
+ case 8: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break;
1695
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q6_K reorder multi-col MMVQ", ncols_dst);
1696
+ }
1697
+ }
1698
+
824
1699
  static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
825
1700
  float *dst, const int ncols,
826
1701
  const int nrows,
@@ -845,6 +1720,51 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
845
1720
  }
846
1721
  }
847
1722
 
1723
+ template <int ncols_dst>
1724
+ static void mul_mat_vec_q6_K_q8_1_sycl_ncols(
1725
+ const void * vx, const void * vy, float * dst,
1726
+ const int ncols, const int nrows,
1727
+ const int stride_col_y, const int stride_col_dst,
1728
+ dpct::queue_ptr stream) {
1729
+ GGML_ASSERT(ncols % QK_K == 0);
1730
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
1731
+ const sycl::range<3> block_nums(1, 1, block_num_y);
1732
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
1733
+
1734
+ stream->submit([&](sycl::handler & cgh) {
1735
+ cgh.parallel_for(
1736
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
1737
+ [=](sycl::nd_item<3> item_ct1)
1738
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1739
+ mul_mat_vec_q_ncols<QK_K, QI6_K, block_q6_K,
1740
+ VDR_Q6_K_Q8_1_MMVQ,
1741
+ vec_dot_q6_K_q8_1,
1742
+ ncols_dst>(
1743
+ vx, vy, dst, ncols, nrows,
1744
+ stride_col_y, stride_col_dst, item_ct1);
1745
+ });
1746
+ });
1747
+ }
1748
+
1749
+ static void mul_mat_vec_q6_K_q8_1_sycl_switch_ncols(
1750
+ const void * vx, const void * vy, float * dst,
1751
+ const int ncols, const int nrows,
1752
+ const int ncols_dst,
1753
+ const int stride_col_y, const int stride_col_dst,
1754
+ dpct::queue_ptr stream) {
1755
+ switch (ncols_dst) {
1756
+ case 1: mul_mat_vec_q6_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1757
+ case 2: mul_mat_vec_q6_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1758
+ case 3: mul_mat_vec_q6_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1759
+ case 4: mul_mat_vec_q6_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1760
+ case 5: mul_mat_vec_q6_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1761
+ case 6: mul_mat_vec_q6_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1762
+ case 7: mul_mat_vec_q6_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1763
+ case 8: mul_mat_vec_q6_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1764
+ default: GGML_ABORT("unsupported ncols_dst=%d for Q6_K multi-col MMVQ", ncols_dst);
1765
+ }
1766
+ }
1767
+
848
1768
 
849
1769
  static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
850
1770
  float *dst, const int ncols,
@@ -1041,6 +1961,51 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
1041
1961
  }
1042
1962
  }
1043
1963
 
1964
+ template <int ncols_dst>
1965
+ static void mul_mat_vec_iq4_xs_q8_1_sycl_ncols(
1966
+ const void * vx, const void * vy, float * dst,
1967
+ const int ncols, const int nrows,
1968
+ const int stride_col_y, const int stride_col_dst,
1969
+ dpct::queue_ptr stream) {
1970
+ GGML_ASSERT(ncols % QK_K == 0);
1971
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
1972
+ const sycl::range<3> block_nums(1, 1, block_num_y);
1973
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
1974
+
1975
+ stream->submit([&](sycl::handler & cgh) {
1976
+ cgh.parallel_for(
1977
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
1978
+ [=](sycl::nd_item<3> item_ct1)
1979
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1980
+ mul_mat_vec_q_ncols<QK_K, QI4_XS/4, block_iq4_xs,
1981
+ 1,
1982
+ vec_dot_iq4_xs_q8_1,
1983
+ ncols_dst>(
1984
+ vx, vy, dst, ncols, nrows,
1985
+ stride_col_y, stride_col_dst, item_ct1);
1986
+ });
1987
+ });
1988
+ }
1989
+
1990
+ static void mul_mat_vec_iq4_xs_q8_1_sycl_switch_ncols(
1991
+ const void * vx, const void * vy, float * dst,
1992
+ const int ncols, const int nrows,
1993
+ const int ncols_dst,
1994
+ const int stride_col_y, const int stride_col_dst,
1995
+ dpct::queue_ptr stream) {
1996
+ switch (ncols_dst) {
1997
+ case 1: mul_mat_vec_iq4_xs_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
1998
+ case 2: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
1999
+ case 3: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
2000
+ case 4: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
2001
+ case 5: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
2002
+ case 6: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
2003
+ case 7: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
2004
+ case 8: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
2005
+ default: GGML_ABORT("unsupported ncols_dst=%d for IQ4_XS multi-col MMVQ", ncols_dst);
2006
+ }
2007
+ }
2008
+
1044
2009
  void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
1045
2010
  ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
1046
2011
  const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low,
@@ -1067,50 +2032,219 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
1067
2032
  case GGML_TYPE_Q4_0:
1068
2033
  if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
1069
2034
  ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
1070
- GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n");
1071
- reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1072
- } else {
2035
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2036
+ const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs;
2037
+ const int stride_col_dst = dst->ne[0];
2038
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2039
+ reorder_mul_mat_vec_q4_0_q8_1_sycl_switch_ncols(
2040
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2041
+ src1_ncols, stride_col_y_bytes, stride_col_dst, stream);
2042
+ return;
2043
+ } else {
2044
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n");
2045
+ reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2046
+ }
2047
+ } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2048
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2049
+ const int stride_col_dst = dst->ne[0];
2050
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2051
+ mul_mat_vec_q4_0_q8_1_sycl_switch_ncols(
2052
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2053
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2054
+ return;
2055
+ } else if (i == 0 || src1_ncols == 1) {
1073
2056
  GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl\n");
1074
2057
  mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1075
2058
  }
1076
2059
  break;
1077
2060
  case GGML_TYPE_Q4_1:
1078
- mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2061
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2062
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2063
+ const int stride_col_dst = dst->ne[0];
2064
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_1_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2065
+ mul_mat_vec_q4_1_q8_1_sycl_switch_ncols(
2066
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2067
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2068
+ return;
2069
+ } else if (i == 0 || src1_ncols == 1) {
2070
+ mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2071
+ }
1079
2072
  break;
1080
2073
  case GGML_TYPE_Q5_0:
1081
- mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2074
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2075
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2076
+ const int stride_col_dst = dst->ne[0];
2077
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2078
+ mul_mat_vec_q5_0_q8_1_sycl_switch_ncols(
2079
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2080
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2081
+ return;
2082
+ } else if (i == 0 || src1_ncols == 1) {
2083
+ mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2084
+ }
1082
2085
  break;
1083
2086
  case GGML_TYPE_Q5_1:
1084
- mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2087
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2088
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2089
+ const int stride_col_dst = dst->ne[0];
2090
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_1_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2091
+ mul_mat_vec_q5_1_q8_1_sycl_switch_ncols(
2092
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2093
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2094
+ return;
2095
+ } else if (i == 0 || src1_ncols == 1) {
2096
+ mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2097
+ }
1085
2098
  break;
1086
2099
  case GGML_TYPE_Q8_0:
1087
- mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2100
+ if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
2101
+ ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
2102
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2103
+ const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs;
2104
+ const int stride_col_dst = dst->ne[0];
2105
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2106
+ reorder_mul_mat_vec_q8_0_q8_1_sycl_switch_ncols(
2107
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2108
+ src1_ncols, stride_col_y_bytes, stride_col_dst, stream);
2109
+ return;
2110
+ } else {
2111
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl\n");
2112
+ reorder_mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2113
+ }
2114
+ } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2115
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2116
+ const int stride_col_dst = dst->ne[0];
2117
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q8_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2118
+ mul_mat_vec_q8_0_q8_1_sycl_switch_ncols(
2119
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2120
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2121
+ return;
2122
+ } else if (i == 0 || src1_ncols == 1) {
2123
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q8_0_q8_1_sycl\n");
2124
+ mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2125
+ }
1088
2126
  break;
1089
2127
  case GGML_TYPE_Q2_K:
1090
- mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2128
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2129
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2130
+ const int stride_col_dst = dst->ne[0];
2131
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q2_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2132
+ mul_mat_vec_q2_K_q8_1_sycl_switch_ncols(
2133
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2134
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2135
+ return;
2136
+ } else if (i == 0 || src1_ncols == 1) {
2137
+ mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2138
+ }
1091
2139
  break;
1092
2140
  case GGML_TYPE_Q3_K:
1093
- mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2141
+ if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
2142
+ ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
2143
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2144
+ const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs;
2145
+ const int stride_col_dst = dst->ne[0];
2146
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q3_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2147
+ reorder_mul_mat_vec_q3_k_q8_1_sycl_switch_ncols(
2148
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2149
+ src1_ncols, stride_col_y_bytes, stride_col_dst, stream);
2150
+ return;
2151
+ } else {
2152
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q3_k_q8_1_sycl\n");
2153
+ reorder_mul_mat_vec_q3_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2154
+ }
2155
+ } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2156
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2157
+ const int stride_col_dst = dst->ne[0];
2158
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q3_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2159
+ mul_mat_vec_q3_K_q8_1_sycl_switch_ncols(
2160
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2161
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2162
+ return;
2163
+ } else if (i == 0 || src1_ncols == 1) {
2164
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q3_K_q8_1_sycl\n");
2165
+ mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2166
+ }
1094
2167
  break;
1095
2168
  case GGML_TYPE_Q4_K:
1096
2169
  if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
1097
2170
  ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
1098
- GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_k_q8_1_sycl\n");
1099
- reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1100
- } else {
2171
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2172
+ const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs;
2173
+ const int stride_col_dst = dst->ne[0];
2174
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2175
+ reorder_mul_mat_vec_q4_k_q8_1_sycl_switch_ncols(
2176
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2177
+ src1_ncols, stride_col_y_bytes, stride_col_dst, stream);
2178
+ return;
2179
+ } else {
2180
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_k_q8_1_sycl\n");
2181
+ reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2182
+ }
2183
+ } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2184
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2185
+ const int stride_col_dst = dst->ne[0];
2186
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2187
+ mul_mat_vec_q4_K_q8_1_sycl_switch_ncols(
2188
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2189
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2190
+ return;
2191
+ } else if (i == 0 || src1_ncols == 1) {
1101
2192
  GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_K_q8_1_sycl\n");
1102
2193
  mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1103
2194
  }
1104
2195
  break;
1105
2196
  case GGML_TYPE_Q5_K:
1106
- mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2197
+ if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
2198
+ ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
2199
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2200
+ const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs;
2201
+ const int stride_col_dst = dst->ne[0];
2202
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q5_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2203
+ reorder_mul_mat_vec_q5_k_q8_1_sycl_switch_ncols(
2204
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2205
+ src1_ncols, stride_col_y_bytes, stride_col_dst, stream);
2206
+ return;
2207
+ } else {
2208
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q5_k_q8_1_sycl\n");
2209
+ reorder_mul_mat_vec_q5_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2210
+ }
2211
+ } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2212
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2213
+ const int stride_col_dst = dst->ne[0];
2214
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2215
+ mul_mat_vec_q5_K_q8_1_sycl_switch_ncols(
2216
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2217
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2218
+ return;
2219
+ } else if (i == 0 || src1_ncols == 1) {
2220
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_K_q8_1_sycl\n");
2221
+ mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2222
+ }
1107
2223
  break;
1108
2224
  case GGML_TYPE_Q6_K:
1109
2225
  if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
1110
2226
  ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
1111
- GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n");
1112
- reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1113
- } else {
2227
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2228
+ const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs;
2229
+ const int stride_col_dst = dst->ne[0];
2230
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2231
+ reorder_mul_mat_vec_q6_k_q8_1_sycl_switch_ncols(
2232
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2233
+ src1_ncols, stride_col_y_bytes, stride_col_dst, stream);
2234
+ return;
2235
+ } else {
2236
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n");
2237
+ reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2238
+ }
2239
+ } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2240
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2241
+ const int stride_col_dst = dst->ne[0];
2242
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2243
+ mul_mat_vec_q6_K_q8_1_sycl_switch_ncols(
2244
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2245
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2246
+ return;
2247
+ } else if (i == 0 || src1_ncols == 1) {
1114
2248
  GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_k_q8_1_sycl\n");
1115
2249
  mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1116
2250
  }
@@ -1140,13 +2274,46 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
1140
2274
  mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1141
2275
  break;
1142
2276
  case GGML_TYPE_IQ4_XS:
1143
- mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2277
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2278
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2279
+ const int stride_col_dst = dst->ne[0];
2280
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_iq4_xs_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2281
+ mul_mat_vec_iq4_xs_q8_1_sycl_switch_ncols(
2282
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2283
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2284
+ return;
2285
+ } else if (i == 0 || src1_ncols == 1) {
2286
+ mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2287
+ }
1144
2288
  break;
1145
2289
  case GGML_TYPE_MXFP4:
1146
- mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2290
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2291
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2292
+ const int stride_col_dst = dst->ne[0];
2293
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_mxfp4_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2294
+ mul_mat_vec_mxfp4_q8_1_sycl_switch_ncols(
2295
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2296
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2297
+ return;
2298
+ } else if (i == 0 || src1_ncols == 1) {
2299
+ mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2300
+ }
2301
+ break;
2302
+ case GGML_TYPE_NVFP4:
2303
+ if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
2304
+ const int stride_col_y = src1_padded_col_size / QK8_1;
2305
+ const int stride_col_dst = dst->ne[0];
2306
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_nvfp4_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
2307
+ mul_mat_vec_nvfp4_q8_1_sycl_switch_ncols(
2308
+ src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
2309
+ src1_ncols, stride_col_y, stride_col_dst, stream);
2310
+ return;
2311
+ } else if (i == 0 || src1_ncols == 1) {
2312
+ mul_mat_vec_nvfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
2313
+ }
1147
2314
  break;
1148
2315
  default:
1149
- GGML_ABORT("fatal error");
2316
+ GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(src0->type));
1150
2317
  }
1151
2318
  }
1152
2319
  GGML_UNUSED(src1);
@@ -1154,3 +2321,154 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
1154
2321
  GGML_UNUSED(src1_ddf_i);
1155
2322
  GGML_UNUSED(ctx);
1156
2323
  }
2324
+
2325
// Device kernel: quantized mat-vec for a set of selected experts in one launch.
//
// Work-group index 1 selects an entry of `ids_dev`; the value read there (i02) picks the
// expert's weights, while the entry index itself picks the src1 and dst slices:
//   weights = vx_base  + i02        * expert_weight_stride   (byte offset)
//   src1    = vy_base  + expert_idx * src1_row_stride        (byte offset)
//   dst     = dst_base + expert_idx * dst_row_stride         (byte offset)
//
// Template parameters:
//   qk, qi, vdr     - block-size / per-block integer / per-call width constants of the
//                     quantization type (used here only for index arithmetic).
//   block_q_t       - weight block type that `vx` is reinterpreted as.
//   vec_dot_q_sycl  - per-block dot product against a block_q8_1 activation block.
//
// Parameters:
//   ncols           - row length in elements; blocks_per_row is ncols / qk (integer division).
//   nrows           - number of output rows; rows >= nrows exit early.
//
// src1_row_stride: 0 for shared src1 (gate/up proj), else per-expert stride (down proj).
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
static void mul_mat_vec_q_moe(
    const void * __restrict__ vx_base, const void * __restrict__ vy_base,
    float * __restrict__ dst_base, const int32_t * __restrict__ ids_dev,
    const int ncols, const int nrows,
    const size_t expert_weight_stride, const size_t dst_row_stride,
    const size_t src1_row_stride,
    const sycl::nd_item<3> & item_ct1) {

    // Group index along dimension 1 enumerates the selected experts.
    const int expert_idx = item_ct1.get_group(1);
    const int i02 = ids_dev[expert_idx];

    // All three strides are applied to char pointers, i.e. they are byte offsets.
    const char * vx = (const char *) vx_base + (size_t) i02 * expert_weight_stride;
    const char * vy = (const char *) vy_base + (size_t) expert_idx * src1_row_stride;
    float * dst = (float *) ((char *) dst_base + (size_t) expert_idx * dst_row_stride);

    // Output row handled by this work-item: groups along dimension 2 each cover
    // get_local_range(1) consecutive rows.
    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);

    // The launch grid is rounded up, so trailing work-items may have no row to compute.
    if (row >= nrows) {
        return;
    }

    const int blocks_per_row = ncols / qk;
    constexpr int blocks_per_warp = (vdr * WARP_SIZE + qi - 1) / qi;

    // Partial dot product accumulated by this work-item.
    float tmp = 0.0f;

    const block_q_t * x = (const block_q_t *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;

    // Work-items along dimension 2 split the row: (qi / vdr) of them share one weight block,
    // and the group advances by blocks_per_warp blocks per iteration.
    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; i += blocks_per_warp) {
        const int ibx = row * blocks_per_row + i;  // weight block index
        const int iby = i * (qk / QK8_1);          // first matching q8_1 block index

        // Steps by WARP_SIZE, so this runs a single iteration whenever qi / vdr <= WARP_SIZE.
        for (size_t elem = 0; elem < qi / vdr; elem += WARP_SIZE) {
            const int iqs = elem + vdr * (item_ct1.get_local_id(2) % (qi / vdr));
            tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
        }
    }

    // XOR-shuffle reduction of the partial sums across the sub-group.
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
        tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }

    // One work-item per row writes the reduced result.
    if (item_ct1.get_local_id(2) == 0) {
        dst[row] = tmp;
    }
}
2375
+
2376
+ template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
2377
+ static void launch_mul_mat_vec_q_moe(
2378
+ const void * vx_base, const void * vy, const int32_t * ids_dev,
2379
+ float * dst_base, const int ncols, const int nrows, const int n_experts_used,
2380
+ const size_t expert_weight_stride, const size_t dst_row_stride,
2381
+ const size_t src1_row_stride,
2382
+ dpct::queue_ptr stream) {
2383
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
2384
+ const sycl::range<3> block_nums(1, (unsigned) n_experts_used, (unsigned) block_num_y);
2385
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
2386
+ stream->submit([&](sycl::handler & cgh) {
2387
+ cgh.parallel_for(
2388
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
2389
+ [=](sycl::nd_item<3> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
2390
+ mul_mat_vec_q_moe<qk, qi, block_q_t, vdr, vec_dot_q_sycl>(
2391
+ vx_base, vy, dst_base, ids_dev, ncols, nrows,
2392
+ expert_weight_stride, dst_row_stride, src1_row_stride, item);
2393
+ });
2394
+ });
2395
+ }
2396
+
2397
+ bool ggml_sycl_mul_mat_vec_q_id(
2398
+ enum ggml_type src0_type,
2399
+ const void * vx_base,
2400
+ const void * vy,
2401
+ const int32_t * ids_dev,
2402
+ float * dst_base,
2403
+ int ncols,
2404
+ int nrows,
2405
+ int n_experts_used,
2406
+ size_t expert_weight_stride,
2407
+ size_t dst_row_stride,
2408
+ size_t src1_row_stride,
2409
+ dpct::queue_ptr stream) {
2410
+ switch (src0_type) {
2411
+ case GGML_TYPE_Q4_0:
2412
+ launch_mul_mat_vec_q_moe<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
2413
+ vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
2414
+ expert_weight_stride, dst_row_stride, src1_row_stride, stream);
2415
+ return true;
2416
+ case GGML_TYPE_Q4_1:
2417
+ launch_mul_mat_vec_q_moe<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
2418
+ vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
2419
+ expert_weight_stride, dst_row_stride, src1_row_stride, stream);
2420
+ return true;
2421
+ case GGML_TYPE_Q5_0:
2422
+ launch_mul_mat_vec_q_moe<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
2423
+ vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
2424
+ expert_weight_stride, dst_row_stride, src1_row_stride, stream);
2425
+ return true;
2426
+ case GGML_TYPE_Q5_1:
2427
+ launch_mul_mat_vec_q_moe<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
2428
+ vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
2429
+ expert_weight_stride, dst_row_stride, src1_row_stride, stream);
2430
+ return true;
2431
+ case GGML_TYPE_Q8_0:
2432
+ launch_mul_mat_vec_q_moe<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
2433
+ vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
2434
+ expert_weight_stride, dst_row_stride, src1_row_stride, stream);
2435
+ return true;
2436
+ case GGML_TYPE_Q2_K:
2437
+ launch_mul_mat_vec_q_moe<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
2438
+ vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
2439
+ expert_weight_stride, dst_row_stride, src1_row_stride, stream);
2440
+ return true;
2441
+ case GGML_TYPE_Q3_K:
2442
+ launch_mul_mat_vec_q_moe<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
2443
+ vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
2444
+ expert_weight_stride, dst_row_stride, src1_row_stride, stream);
2445
+ return true;
2446
+ case GGML_TYPE_Q4_K:
2447
+ launch_mul_mat_vec_q_moe<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
2448
+ vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
2449
+ expert_weight_stride, dst_row_stride, src1_row_stride, stream);
2450
+ return true;
2451
+ case GGML_TYPE_Q5_K:
2452
+ launch_mul_mat_vec_q_moe<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
2453
+ vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
2454
+ expert_weight_stride, dst_row_stride, src1_row_stride, stream);
2455
+ return true;
2456
+ case GGML_TYPE_Q6_K:
2457
+ launch_mul_mat_vec_q_moe<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
2458
+ vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
2459
+ expert_weight_stride, dst_row_stride, src1_row_stride, stream);
2460
+ return true;
2461
+ case GGML_TYPE_MXFP4:
2462
+ launch_mul_mat_vec_q_moe<QK_MXFP4, QI_MXFP4, block_mxfp4, VDR_MXFP4_Q8_1_MMVQ, vec_dot_mxfp4_q8_1>(
2463
+ vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
2464
+ expert_weight_stride, dst_row_stride, src1_row_stride, stream);
2465
+ return true;
2466
+ case GGML_TYPE_NVFP4:
2467
+ launch_mul_mat_vec_q_moe<QK_NVFP4, QI_NVFP4, block_nvfp4, VDR_NVFP4_Q8_1_MMVQ, vec_dot_nvfp4_q8_1>(
2468
+ vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,
2469
+ expert_weight_stride, dst_row_stride, src1_row_stride, stream);
2470
+ return true;
2471
+ default:
2472
+ return false;
2473
+ }
2474
+ }