whispercpp 1.3.6 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (828)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/README.md +38 -5
  5. data/Rakefile +18 -3
  6. data/ext/dependencies.rb +10 -4
  7. data/ext/dependencies_for_windows.rb +17 -0
  8. data/ext/extconf.rb +20 -8
  9. data/ext/options.rb +54 -14
  10. data/ext/options_for_windows.rb +51 -0
  11. data/ext/ruby_whisper.c +36 -42
  12. data/ext/ruby_whisper.h +135 -0
  13. data/ext/ruby_whisper_context.c +107 -28
  14. data/ext/ruby_whisper_log_queue.c +180 -0
  15. data/ext/ruby_whisper_log_settable.h +47 -0
  16. data/ext/ruby_whisper_parakeet.c +49 -0
  17. data/ext/ruby_whisper_parakeet_context.c +304 -0
  18. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  19. data/ext/ruby_whisper_parakeet_model.c +84 -0
  20. data/ext/ruby_whisper_parakeet_params.c +548 -0
  21. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  22. data/ext/ruby_whisper_parakeet_token.c +188 -0
  23. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  24. data/ext/ruby_whisper_params.c +256 -65
  25. data/ext/ruby_whisper_segment.c +6 -6
  26. data/ext/ruby_whisper_transcribe.cpp +42 -15
  27. data/ext/sources/CMakeLists.txt +41 -3
  28. data/ext/sources/CMakePresets.json +95 -0
  29. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  30. data/ext/sources/cmake/parakeet.pc.in +10 -0
  31. data/ext/sources/cmake/whisper.pc.in +1 -1
  32. data/ext/sources/examples/CMakeLists.txt +4 -2
  33. data/ext/sources/examples/bench/bench.cpp +1 -1
  34. data/ext/sources/examples/cli/cli.cpp +43 -9
  35. data/ext/sources/examples/common-ggml.cpp +2 -0
  36. data/ext/sources/examples/common-whisper.cpp +139 -67
  37. data/ext/sources/examples/common-whisper.h +11 -0
  38. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  39. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  40. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  41. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  42. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  43. data/ext/sources/examples/server/server.cpp +199 -163
  44. data/ext/sources/ggml/CMakeLists.txt +21 -13
  45. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  46. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  47. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  48. data/ext/sources/ggml/include/ggml-backend.h +72 -10
  49. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  50. data/ext/sources/ggml/include/ggml-rpc.h +3 -3
  51. data/ext/sources/ggml/include/ggml.h +101 -9
  52. data/ext/sources/ggml/include/gguf.h +10 -2
  53. data/ext/sources/ggml/src/CMakeLists.txt +22 -5
  54. data/ext/sources/ggml/src/ggml-alloc.c +5 -1
  55. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  56. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  57. data/ext/sources/ggml/src/ggml-backend-reg.cpp +12 -0
  58. data/ext/sources/ggml/src/ggml-backend.cpp +110 -9
  59. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +4 -0
  60. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +672 -257
  61. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +71 -0
  62. data/ext/sources/ggml/src/ggml-cann/common.h +20 -10
  63. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +211 -30
  64. data/ext/sources/ggml/src/ggml-common.h +11 -0
  65. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +58 -29
  66. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +2 -0
  67. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +16 -16
  68. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +116 -7
  69. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +65 -0
  70. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  71. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  72. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +4279 -1292
  73. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +5 -35
  74. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +0 -1
  75. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  76. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +177 -27
  77. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +1 -1
  78. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +5 -0
  79. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  80. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -0
  81. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +95 -5
  82. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  83. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +146 -134
  84. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +88 -70
  85. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +372 -73
  86. data/ext/sources/ggml/src/ggml-cpu/ops.h +3 -0
  87. data/ext/sources/ggml/src/ggml-cpu/quants.c +55 -0
  88. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  89. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3 -0
  90. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +90 -0
  91. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +3 -16
  92. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  93. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  94. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  95. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  96. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  97. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  98. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  99. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  100. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  101. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  102. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  103. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  104. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  105. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  106. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  107. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +37 -53
  108. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  109. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +17 -7
  110. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  111. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  112. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +62 -26
  113. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +44 -18
  114. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  115. data/ext/sources/ggml/src/ggml-cuda/common.cuh +242 -28
  116. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  117. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  118. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  119. data/ext/sources/ggml/src/ggml-cuda/convert.cu +53 -0
  120. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  121. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +14 -6
  122. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  123. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +278 -44
  124. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +331 -130
  125. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +126 -27
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +40 -15
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +18 -9
  129. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +152 -49
  130. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  131. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  132. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  133. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +84 -35
  134. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  135. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1069 -609
  136. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  137. data/ext/sources/ggml/src/ggml-cuda/mean.cu +4 -2
  138. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +242 -195
  139. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +3 -3
  140. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  141. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +502 -423
  142. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +19 -12
  143. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +485 -57
  144. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -1
  145. data/ext/sources/ggml/src/ggml-cuda/norm.cu +36 -10
  146. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  147. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +133 -26
  148. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  149. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +5 -1
  150. data/ext/sources/ggml/src/ggml-cuda/rope.cu +11 -4
  151. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  152. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  153. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  154. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  155. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  156. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +45 -13
  157. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  158. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  159. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  160. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  161. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +1 -0
  162. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  163. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  164. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  165. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +1 -0
  166. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  167. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  168. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  169. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  170. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  171. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  172. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  173. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  174. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  176. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  177. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  178. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  179. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  191. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -4
  192. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +26 -23
  193. data/ext/sources/ggml/src/ggml-cuda/unary.cu +31 -2
  194. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  195. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +80 -0
  196. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  197. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -4
  198. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  199. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +2 -1
  200. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1428 -743
  201. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -7
  202. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +53 -84
  203. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +25 -12
  204. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +165 -184
  205. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  206. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  207. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +170 -127
  208. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  209. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  210. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  211. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +125 -97
  212. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  213. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +148 -42
  214. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +2 -2
  215. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +252 -62
  216. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +9 -0
  217. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +87 -1
  218. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  219. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  220. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  221. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  222. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  223. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  224. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  225. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  226. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +96 -13
  227. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +182 -57
  228. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  229. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +71 -3
  230. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +27 -10
  231. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +63 -23
  232. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +9 -8
  233. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  235. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +1 -0
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +5 -8
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +529 -815
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2522 -234
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +291 -95
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +59 -37
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +121 -133
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +244 -151
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +6 -6
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +719 -45
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +3 -1
  254. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +22 -9
  255. data/ext/sources/ggml/src/ggml-impl.h +6 -1
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +138 -13
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +32 -1
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +164 -28
  259. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +80 -0
  260. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +190 -19
  261. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +2 -0
  262. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +39 -26
  263. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +823 -322
  264. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  265. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +54 -5
  266. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +12248 -5907
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +67 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +59 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1819 -112
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mm_q8_0_f32_8x4.cl → gemm_noshuffle_q8_0_f32.cl} +1 -1
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +48 -64
  322. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +15 -5
  323. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +18 -11
  324. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +35 -13
  325. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +264 -192
  326. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +33 -7
  327. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  328. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +1 -0
  329. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +1 -0
  330. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  331. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +27 -3
  332. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +67 -36
  333. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +1 -0
  334. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +101 -44
  335. data/ext/sources/ggml/src/ggml-openvino/utils.h +23 -3
  336. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  337. data/ext/sources/ggml/src/ggml-quants.c +289 -114
  338. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  339. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  340. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  341. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  342. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  343. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +50 -4
  344. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +1 -1
  345. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +3 -1
  346. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  347. data/ext/sources/ggml/src/ggml-sycl/common.hpp +41 -1
  348. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +115 -13
  349. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +9 -0
  350. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  351. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  352. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  353. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  354. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  355. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  356. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1 -90
  357. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +0 -2
  358. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  359. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  360. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +7 -5
  361. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +4 -0
  362. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +76 -168
  363. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +7 -0
  364. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +3 -1
  365. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  366. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  367. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +69 -31
  368. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +1 -0
  369. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  370. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  371. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +823 -190
  372. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  373. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  374. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  375. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  376. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  377. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +71 -0
  378. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  379. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  380. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  381. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  382. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  383. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  384. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  385. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  386. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  387. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +1 -0
  388. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +1 -0
  389. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +1 -0
  390. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +1 -0
  391. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +1 -0
  392. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +1 -0
  393. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +1 -0
  394. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +1 -0
  395. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +1 -0
  396. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +1 -0
  397. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +1 -0
  398. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +1 -0
  399. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +1 -0
  400. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +1 -0
  401. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +1 -0
  402. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +1 -0
  403. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +1 -0
  404. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +1 -0
  405. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +1 -0
  406. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +1 -0
  407. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +1 -0
  408. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +1 -0
  409. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +1 -0
  410. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +1 -0
  411. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +1 -0
  412. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +1 -0
  413. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +1 -0
  414. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +1 -0
  415. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +1 -0
  416. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +1 -0
  417. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +1 -0
  418. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +1 -0
  420. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +1 -0
  421. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +1 -0
  422. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +1 -0
  423. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  424. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  425. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  426. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +215 -53
  427. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +4 -0
  428. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +2 -0
  429. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +2 -0
  430. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +1 -0
  431. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +1 -0
  432. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +0 -2
  433. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +11 -0
  434. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +2060 -535
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +197 -48
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +60 -59
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +115 -113
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +122 -31
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +125 -64
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -17
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +43 -10
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +5 -2
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +3 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +11 -1
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +171 -147
  484. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  485. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +2202 -283
  486. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2610 -1403
  487. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  488. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  489. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +8 -7
  490. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +76 -95
  491. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +19 -1
  492. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  493. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -50
  494. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +107 -184
  495. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  496. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  497. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  498. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  499. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  500. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  501. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +183 -78
  502. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  503. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  504. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +655 -495
  505. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  506. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  507. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  508. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +8 -6
  509. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +5 -1
  510. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +80 -409
  511. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  512. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  513. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  514. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  515. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  516. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  517. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  518. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +6 -4
  519. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  520. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +2 -3
  521. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  522. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  523. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  524. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  525. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  526. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +68 -48
  527. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  528. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +18 -14
  529. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +1 -1
  530. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +244 -10
  531. data/ext/sources/ggml/src/ggml.c +110 -28
  532. data/ext/sources/ggml/src/gguf.cpp +173 -28
  533. data/ext/sources/include/parakeet.h +342 -0
  534. data/ext/sources/include/whisper.h +10 -0
  535. data/ext/sources/media/matmul.png +0 -0
  536. data/ext/sources/src/CMakeLists.txt +23 -0
  537. data/ext/sources/src/parakeet-arch.h +188 -0
  538. data/ext/sources/src/parakeet.cpp +3838 -0
  539. data/ext/sources/src/whisper.cpp +56 -12
  540. data/extsources.rb +26 -10
  541. data/lib/whisper/log_settable.rb +36 -0
  542. data/lib/whisper/model/uri.rb +13 -1
  543. data/lib/whisper/output.rb +74 -0
  544. data/sig/whisper.rbs +411 -62
  545. data/test/helper.rb +2 -0
  546. data/test/jfk_reader/jfk_reader.c +50 -7
  547. data/test/test_callback.rb +1 -0
  548. data/test/test_package.rb +6 -5
  549. data/test/test_parakeet.rb +28 -0
  550. data/test/test_parakeet_callback.rb +107 -0
  551. data/test/test_parakeet_context.rb +116 -0
  552. data/test/test_parakeet_context_params.rb +24 -0
  553. data/test/test_parakeet_model.rb +21 -0
  554. data/test/test_parakeet_params.rb +78 -0
  555. data/test/test_parakeet_segment.rb +42 -0
  556. data/test/test_parakeet_token.rb +73 -0
  557. data/test/test_params.rb +2 -0
  558. data/test/test_vad_segment.rb +1 -1
  559. data/test/test_whisper.rb +24 -6
  560. data/whispercpp.gemspec +2 -2
  561. metadata +215 -281
  562. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  563. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  564. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  565. data/ext/sources/bindings/javascript/package.json +0 -26
  566. data/ext/sources/bindings/javascript/whisper.js +0 -19
  567. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  568. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  569. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  570. data/ext/sources/examples/addon.node/index.js +0 -59
  571. data/ext/sources/examples/addon.node/package.json +0 -16
  572. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  573. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  574. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  575. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  576. data/ext/sources/examples/coi-serviceworker.js +0 -146
  577. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  578. data/ext/sources/examples/command/command.cpp +0 -802
  579. data/ext/sources/examples/command/commands.txt +0 -9
  580. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  581. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  582. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  583. data/ext/sources/examples/generate-karaoke.sh +0 -57
  584. data/ext/sources/examples/helpers.js +0 -191
  585. data/ext/sources/examples/livestream.sh +0 -112
  586. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  587. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  588. data/ext/sources/examples/lsp/whisper.vim +0 -362
  589. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  590. data/ext/sources/examples/python/whisper_processor.py +0 -54
  591. data/ext/sources/examples/server/bench.js +0 -29
  592. data/ext/sources/examples/server.py +0 -120
  593. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  594. data/ext/sources/examples/stream/stream.cpp +0 -437
  595. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  596. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  597. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  598. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  599. data/ext/sources/examples/sycl/build.sh +0 -22
  600. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  601. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  602. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -48
  603. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  604. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -488
  605. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -89
  606. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2877
  607. data/ext/sources/examples/talk-llama/llama-arch.h +0 -628
  608. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -919
  609. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  610. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -896
  611. data/ext/sources/examples/talk-llama/llama-chat.h +0 -71
  612. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3633
  613. data/ext/sources/examples/talk-llama/llama-context.h +0 -359
  614. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  615. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -47
  616. data/ext/sources/examples/talk-llama/llama-ext.h +0 -12
  617. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  618. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  619. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2735
  620. data/ext/sources/examples/talk-llama/llama-graph.h +0 -1031
  621. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -258
  622. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -353
  623. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  624. data/ext/sources/examples/talk-llama/llama-impl.h +0 -75
  625. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  626. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  627. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -330
  628. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  629. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2285
  630. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -389
  631. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  632. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +0 -275
  633. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +0 -140
  634. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  635. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  636. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1165
  637. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  638. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  639. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  640. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -752
  641. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  642. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1655
  643. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -206
  644. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -299
  645. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -40
  646. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -9056
  647. data/ext/sources/examples/talk-llama/llama-model.h +0 -597
  648. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1304
  649. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  650. data/ext/sources/examples/talk-llama/llama-sampler.cpp +0 -3885
  651. data/ext/sources/examples/talk-llama/llama-sampler.h +0 -42
  652. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3970
  653. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -187
  654. data/ext/sources/examples/talk-llama/llama.cpp +0 -1194
  655. data/ext/sources/examples/talk-llama/llama.h +0 -1573
  656. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -190
  657. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  658. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  659. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -137
  660. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  661. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -123
  662. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -143
  663. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -133
  664. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -184
  665. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -145
  666. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  667. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  668. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  669. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  670. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  671. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  672. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  673. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -122
  674. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  675. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -142
  676. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -262
  677. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +0 -445
  678. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -132
  679. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  680. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -148
  681. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  682. data/ext/sources/examples/talk-llama/models/eurobert.cpp +0 -97
  683. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +0 -145
  684. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  685. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  686. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -111
  687. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  688. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  689. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  690. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  691. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  692. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  693. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  694. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -157
  695. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  696. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  697. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -195
  698. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -210
  699. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  700. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -139
  701. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  702. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -153
  703. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  704. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  705. data/ext/sources/examples/talk-llama/models/jais2.cpp +0 -123
  706. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  707. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +0 -381
  708. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -196
  709. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  710. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  711. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  712. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -175
  713. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  714. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +0 -289
  715. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -54
  716. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -129
  717. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -200
  718. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -123
  719. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  720. data/ext/sources/examples/talk-llama/models/models.h +0 -704
  721. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -109
  722. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  723. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -162
  724. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  725. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  726. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  727. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  728. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  729. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  730. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  731. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  732. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +0 -122
  733. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  734. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  735. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  736. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  737. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -320
  738. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  739. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -169
  740. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  741. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  742. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  743. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  744. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -120
  745. data/ext/sources/examples/talk-llama/models/qwen35.cpp +0 -381
  746. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +0 -422
  747. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -131
  748. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -525
  749. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -140
  750. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -132
  751. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  752. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  753. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -164
  754. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  755. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  756. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -137
  757. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  758. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  759. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  760. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  761. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  762. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  763. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  764. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +0 -165
  765. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  766. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  767. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  768. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  769. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  770. data/ext/sources/examples/talk-llama/speak +0 -40
  771. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  772. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  773. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  774. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  775. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  776. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1103
  777. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  778. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  779. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  780. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  781. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  782. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  783. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  784. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  785. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  786. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  787. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  788. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  789. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  790. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  791. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -155
  792. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  793. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  794. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +0 -123
  795. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +0 -17
  796. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +0 -333
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  798. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -182
  799. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  800. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -718
  801. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  802. data/ext/sources/tests/CMakeLists.txt +0 -112
  803. data/ext/sources/tests/earnings21/eval.mk +0 -58
  804. data/ext/sources/tests/earnings21/eval.py +0 -68
  805. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  806. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  807. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  808. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  809. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  810. data/ext/sources/tests/en-0-ref.txt +0 -1
  811. data/ext/sources/tests/en-1-ref.txt +0 -1
  812. data/ext/sources/tests/en-2-ref.txt +0 -1
  813. data/ext/sources/tests/es-0-ref.txt +0 -1
  814. data/ext/sources/tests/librispeech/eval.mk +0 -39
  815. data/ext/sources/tests/librispeech/eval.py +0 -47
  816. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  817. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  818. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  819. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  820. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  821. data/ext/sources/tests/run-tests.sh +0 -130
  822. data/ext/sources/tests/test-c.c +0 -3
  823. data/ext/sources/tests/test-vad-full.cpp +0 -56
  824. data/ext/sources/tests/test-vad.cpp +0 -83
  825. data/ext/sources/tests/test-whisper.js +0 -58
  826. data/lib/whisper/context.rb +0 -15
  827. data/lib/whisper/segment.rb +0 -58
  828. /data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general_q8_0_f32.cl → gemv_noshuffle_q8_0_f32.cl} +0 -0
@@ -1,6 +1,7 @@
1
1
  #ifndef GGML_WEBGPU_SHADER_LIB_HPP
2
2
  #define GGML_WEBGPU_SHADER_LIB_HPP
3
3
 
4
+ #include "ggml-impl.h"
4
5
  #include "ggml-wgsl-shaders.hpp"
5
6
  #include "ggml.h"
6
7
  #include "pre_wgsl.hpp"
@@ -17,6 +18,9 @@
17
18
  #define GGML_WEBGPU_F32_SIZE_BYTES 4
18
19
  #define GGML_WEBGPU_I32_SIZE_BYTES 4
19
20
  #define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u
21
+ #define GGML_WEBGPU_FLASH_ATTN_VEC_MAX_SEQ_LEN 20u
22
+ #define GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE 32u
23
+ #define GGML_WEBGPU_FLASH_ATTN_TILE_MAX_KV_TILE 64u
20
24
  #define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u
21
25
  // Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing.
22
26
  #define GGML_WEBGPU_KV_SEQ_PAD 256u
@@ -26,38 +30,32 @@
26
30
  // Matrix multiplication parameters
27
31
 
28
32
  // Register tiling parameters
29
- #define WEBGPU_MUL_MAT_TILE_M 8
30
- #define WEBGPU_MUL_MAT_TILE_N 8
31
- #define WEBGPU_MUL_MAT_WG_SIZE_M 8
32
- #define WEBGPU_MUL_MAT_WG_SIZE_N 8
33
- #define WEBGPU_MUL_MAT_TILE_K 32
33
+ #define WEBGPU_MUL_MAT_TILE_M 4
34
+ #define WEBGPU_MUL_MAT_TILE_N 4
35
+ #define WEBGPU_MUL_MAT_WG_SIZE_M 8
36
+ #define WEBGPU_MUL_MAT_WG_SIZE_N 8
37
+ #define WEBGPU_MUL_MAT_REG_TILE_K_FLOAT 8
38
+ #define WEBGPU_MUL_MAT_REG_TILE_K_QUANT 32
34
39
 
35
40
  // Subgroup matrix parameters
36
41
  // The number of subgroups in the M dimension
37
- #define WEBGPU_MUL_MAT_SUBGROUP_M 2
42
+ #define WEBGPU_MUL_MAT_SUBGROUP_M 2
38
43
  // The number of subgroups in the N dimension
39
- #define WEBGPU_MUL_MAT_SUBGROUP_N 2
44
+ #define WEBGPU_MUL_MAT_SUBGROUP_N 4
40
45
  // The number of subgroup matrices each subgroup accumulates over
41
- #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
42
- #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
46
+ #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
47
+ #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
48
+ #define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT 32
49
+ #define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT 32
43
50
 
44
51
  // Matrix-vector multiplication parameters
45
52
  #define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
46
53
 
47
- // Must be multiple of 4 to work with vectorized paths, and must divide
48
- // mul_mat_vec wg size
49
- #define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64
50
- #define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256
54
+ #define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 4
55
+ #define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 4
56
+ #define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 4
51
57
 
52
- #define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64
53
- #define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256
54
-
55
- // Requires 32 threads per output (wg_size/outputs_per_wg == 32)
56
- #define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8
57
- // Requires at least two (and multiple of 2) k-quant blocks per tile
58
- #define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512
59
-
60
- // default size for legacy matrix multiplication
58
+ // default size for reg-tile matrix multiplication
61
59
  #define WEBGPU_MUL_MAT_WG_SIZE 256
62
60
 
63
61
  // Same hash combine function as in boost
@@ -65,24 +63,41 @@ template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const
65
63
  seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
66
64
  }
67
65
 
66
+ // Calculates base address of a tensor ignoring the fake base pointer
67
+ inline uintptr_t ggml_webgpu_tensor_addr(const ggml_tensor * tensor) {
68
+ const ggml_tensor * base_tensor = tensor->view_src ? tensor->view_src : tensor;
69
+ return (uintptr_t) base_tensor->data + tensor->view_offs;
70
+ }
71
+
72
+ inline bool ggml_webgpu_tensor_equal(const ggml_tensor * a, const ggml_tensor * b) {
73
+ return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) == ggml_webgpu_tensor_addr(b);
74
+ }
75
+
76
+ inline bool ggml_webgpu_tensor_overlap(const ggml_tensor * a, const ggml_tensor * b) {
77
+ return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) < ggml_webgpu_tensor_addr(b) + ggml_nbytes(b) &&
78
+ ggml_webgpu_tensor_addr(b) < ggml_webgpu_tensor_addr(a) + ggml_nbytes(a);
79
+ }
80
+
68
81
  struct ggml_webgpu_shader_lib_context {
69
82
  ggml_tensor * src0;
70
83
  ggml_tensor * src1;
71
84
  ggml_tensor * src2;
72
85
  ggml_tensor * src3;
73
86
  ggml_tensor * src4;
87
+ ggml_tensor * src5;
74
88
  ggml_tensor * dst;
75
89
 
76
- uint32_t max_wg_size;
77
- size_t wg_mem_limit_bytes = 0;
78
- bool inplace = false;
79
- bool overlap = false;
80
- bool src_overlap = false;
81
- bool supports_subgroup_matrix = false;
82
- uint32_t sg_mat_m = 0;
83
- uint32_t sg_mat_n = 0;
84
- uint32_t sg_mat_k = 0;
85
- uint32_t max_subgroup_size = 0;
90
+ uint32_t max_wg_size;
91
+ size_t wg_mem_limit_bytes = 0;
92
+ bool supports_subgroups = false;
93
+ bool supports_subgroup_matrix = false;
94
+ uint32_t sg_mat_m = 0;
95
+ uint32_t sg_mat_n = 0;
96
+ uint32_t sg_mat_k = 0;
97
+ uint32_t min_subgroup_size = 0;
98
+ uint32_t max_subgroup_size = 0;
99
+ bool supports_dot_product = false;
100
+ std::string vendor;
86
101
  };
87
102
 
88
103
  struct webgpu_pipeline {
@@ -93,6 +108,51 @@ struct webgpu_pipeline {
93
108
 
// Launch decisions for shaders that need only a workgroup size and an
// in-place flag.
struct ggml_webgpu_generic_shader_decisions {
    uint32_t wg_size{ 0 };
    bool     inplace{ false };
};
113
+
114
// Launch decisions for binary ops, including the aliasing flags that select
// the shader variant.
struct ggml_webgpu_binary_shader_decisions {
    uint32_t wg_size{ 0 };
    bool     inplace{ false };
    bool     overlap{ false };
    bool     src_overlap{ false };
};
120
+
121
// Result of preprocessing a shader: the WGSL source, the variant name, and an
// opaque, type-erased pointer to the op-specific decision struct.
struct ggml_webgpu_processed_shader {
    std::string           wgsl;
    std::string           variant;
    std::shared_ptr<void> decisions;  // concrete type depends on the op
};
126
+
127
// Launch decisions for the SSM convolution shader. Defaults to zero, like the
// other *_shader_decisions structs, so a default-constructed instance is never
// read uninitialized.
struct ggml_webgpu_ssm_conv_shader_decisions {
    uint32_t block_size    = 0;
    uint32_t tokens_per_wg = 0;
};
131
+
132
+ struct ggml_webgpu_ssm_scan_pipeline_key {
133
+ int type;
134
+ int d_state;
135
+ bool xbc_overlap;
136
+
137
+ bool operator==(const ggml_webgpu_ssm_scan_pipeline_key & other) const {
138
+ return type == other.type && d_state == other.d_state && xbc_overlap == other.xbc_overlap;
139
+ }
140
+ };
141
+
142
+ struct ggml_webgpu_ssm_scan_pipeline_key_hash {
143
+ size_t operator()(const ggml_webgpu_ssm_scan_pipeline_key & key) const {
144
+ size_t seed = 0;
145
+ ggml_webgpu_hash_combine(seed, key.type);
146
+ ggml_webgpu_hash_combine(seed, key.d_state);
147
+ ggml_webgpu_hash_combine(seed, key.xbc_overlap);
148
+ return seed;
149
+ }
150
+ };
151
+
152
// Launch decisions for the SSM scan shader. Every member has a default; before
// this change wg_size and tokens_per_tile were left uninitialized even though
// xbc_overlap had one.
struct ggml_webgpu_ssm_scan_shader_decisions {
    uint32_t wg_size         = 0;
    uint32_t tokens_per_tile = 0;
    bool     xbc_overlap     = false;
};
97
157
 
98
158
  /** Argsort **/
@@ -109,9 +169,11 @@ struct ggml_webgpu_set_rows_pipeline_key {
109
169
  int dst_type;
110
170
  int vec4;
111
171
  int i64_idx;
172
+ int pair_blocks;
112
173
 
113
174
  bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {
114
- return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx;
175
+ return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx &&
176
+ pair_blocks == other.pair_blocks;
115
177
  }
116
178
  };
117
179
 
@@ -121,6 +183,7 @@ struct ggml_webgpu_set_rows_pipeline_key_hash {
121
183
  ggml_webgpu_hash_combine(seed, key.dst_type);
122
184
  ggml_webgpu_hash_combine(seed, key.vec4);
123
185
  ggml_webgpu_hash_combine(seed, key.i64_idx);
186
+ ggml_webgpu_hash_combine(seed, key.pair_blocks);
124
187
  return seed;
125
188
  }
126
189
  };
@@ -128,9 +191,30 @@ struct ggml_webgpu_set_rows_pipeline_key_hash {
// Launch decisions for the set_rows shader. Zero/false defaults match the other
// *_shader_decisions structs so no field is ever read uninitialized.
struct ggml_webgpu_set_rows_shader_decisions {
    bool     vec4        = false;
    bool     i64_idx     = false;
    bool     pair_blocks = false;
    uint32_t wg_size     = 0;
};
133
197
 
198
+ /** Set **/
199
+
200
+ struct ggml_webgpu_set_pipeline_key {
201
+ ggml_type type;
202
+ bool inplace;
203
+
204
+ bool operator==(const ggml_webgpu_set_pipeline_key & other) const {
205
+ return type == other.type && inplace == other.inplace;
206
+ }
207
+ };
208
+
209
+ struct ggml_webgpu_set_pipeline_key_hash {
210
+ size_t operator()(const ggml_webgpu_set_pipeline_key & key) const {
211
+ size_t seed = 0;
212
+ ggml_webgpu_hash_combine(seed, key.type);
213
+ ggml_webgpu_hash_combine(seed, key.inplace);
214
+ return seed;
215
+ }
216
+ };
217
+
134
218
  /** Get Rows **/
135
219
 
136
220
  struct ggml_webgpu_get_rows_pipeline_key {
@@ -151,6 +235,59 @@ struct ggml_webgpu_get_rows_pipeline_key_hash {
151
235
  }
152
236
  };
153
237
 
238
+ /** Row Norm **/
239
+
240
+ struct ggml_webgpu_row_norm_pipeline_key {
241
+ ggml_op op;
242
+ ggml_type src_type;
243
+ ggml_type dst_type;
244
+ bool inplace;
245
+
246
+ bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const {
247
+ return op == other.op && src_type == other.src_type && dst_type == other.dst_type && inplace == other.inplace;
248
+ }
249
+ };
250
+
251
+ struct ggml_webgpu_row_norm_pipeline_key_hash {
252
+ size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const {
253
+ size_t seed = 0;
254
+ ggml_webgpu_hash_combine(seed, key.op);
255
+ ggml_webgpu_hash_combine(seed, key.src_type);
256
+ ggml_webgpu_hash_combine(seed, key.dst_type);
257
+ ggml_webgpu_hash_combine(seed, key.inplace);
258
+ return seed;
259
+ }
260
+ };
261
+
262
+ /** RMS_NORM + MUL **/
263
+
264
+ struct ggml_webgpu_rms_norm_mul_pipeline_key {
265
+ bool inplace; // rn_src == dst
266
+ bool overlap; // mul_src == dst
267
+ bool src_overlap; // rn_src == mul_src
268
+
269
+ bool operator==(const ggml_webgpu_rms_norm_mul_pipeline_key & other) const {
270
+ return inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap;
271
+ }
272
+ };
273
+
274
+ struct ggml_webgpu_rms_norm_mul_pipeline_key_hash {
275
+ size_t operator()(const ggml_webgpu_rms_norm_mul_pipeline_key & key) const {
276
+ size_t seed = 0;
277
+ ggml_webgpu_hash_combine(seed, key.inplace);
278
+ ggml_webgpu_hash_combine(seed, key.overlap);
279
+ ggml_webgpu_hash_combine(seed, key.src_overlap);
280
+ return seed;
281
+ }
282
+ };
283
+
284
// Launch decisions for the fused RMS_NORM + MUL shader.
struct ggml_webgpu_rms_norm_mul_shader_decisions {
    uint32_t wg_size{ 0 };
    bool     inplace{ false };
    bool     overlap{ false };
    bool     src_overlap{ false };
};
290
+
154
291
  /** Pad **/
155
292
  struct ggml_webgpu_pad_pipeline_key {
156
293
  bool circular;
@@ -166,6 +303,107 @@ struct ggml_webgpu_pad_pipeline_key_hash {
166
303
  }
167
304
  };
168
305
 
306
+ /** Solve Tri **/
307
+ struct ggml_webgpu_solve_tri_pipeline_key {
308
+ int type;
309
+ int n;
310
+ int k;
311
+
312
+ bool operator==(const ggml_webgpu_solve_tri_pipeline_key & other) const {
313
+ return type == other.type && n == other.n && k == other.k;
314
+ }
315
+ };
316
+
317
+ struct ggml_webgpu_solve_tri_pipeline_key_hash {
318
+ size_t operator()(const ggml_webgpu_solve_tri_pipeline_key & key) const {
319
+ size_t seed = 0;
320
+ ggml_webgpu_hash_combine(seed, key.type);
321
+ ggml_webgpu_hash_combine(seed, key.n);
322
+ ggml_webgpu_hash_combine(seed, key.k);
323
+ return seed;
324
+ }
325
+ };
326
+
327
+ /** SSM Conv **/
328
// Key identifying a compiled SSM convolution pipeline (type and vectorized variant).
struct ggml_webgpu_ssm_conv_pipeline_key {
    int type;
    int vectorized;

    bool operator==(const ggml_webgpu_ssm_conv_pipeline_key & rhs) const {
        if (type != rhs.type) {
            return false;
        }
        return vectorized == rhs.vectorized;
    }
};
336
+
337
+ /** CONV 2D */
338
+ struct ggml_webgpu_conv2d_pipeline_key {
339
+ ggml_type weight_type;
340
+ ggml_type input_type;
341
+ ggml_type output_type;
342
+
343
+ bool operator==(const ggml_webgpu_conv2d_pipeline_key & other) const {
344
+ return weight_type == other.weight_type && input_type == other.input_type && output_type == other.output_type;
345
+ }
346
+ };
347
+
348
+ struct ggml_webgpu_conv2d_pipeline_key_hash {
349
+ size_t operator()(const ggml_webgpu_conv2d_pipeline_key & key) const {
350
+ size_t seed = 0;
351
+ ggml_webgpu_hash_combine(seed, key.weight_type);
352
+ ggml_webgpu_hash_combine(seed, key.input_type);
353
+ ggml_webgpu_hash_combine(seed, key.output_type);
354
+ return seed;
355
+ }
356
+ };
357
+
358
+ /** Im2Col **/
359
+ struct ggml_webgpu_im2col_pipeline_key {
360
+ ggml_type input_type;
361
+ ggml_type output_type;
362
+
363
+ bool operator==(const ggml_webgpu_im2col_pipeline_key & other) const {
364
+ return input_type == other.input_type && output_type == other.output_type;
365
+ }
366
+ };
367
+
368
+ struct ggml_webgpu_im2col_pipeline_key_hash {
369
+ size_t operator()(const ggml_webgpu_im2col_pipeline_key & key) const {
370
+ size_t seed = 0;
371
+ ggml_webgpu_hash_combine(seed, key.input_type);
372
+ ggml_webgpu_hash_combine(seed, key.output_type);
373
+ return seed;
374
+ }
375
+ };
376
+
377
+ /** Gated Delta Net **/
378
+ struct ggml_webgpu_gated_delta_net_pipeline_key {
379
+ int type;
380
+ int s_v;
381
+ int kda;
382
+
383
+ bool operator==(const ggml_webgpu_gated_delta_net_pipeline_key & other) const {
384
+ return type == other.type && s_v == other.s_v && kda == other.kda;
385
+ }
386
+ };
387
+
388
+ struct ggml_webgpu_gated_delta_net_pipeline_key_hash {
389
+ size_t operator()(const ggml_webgpu_gated_delta_net_pipeline_key & key) const {
390
+ size_t seed = 0;
391
+ ggml_webgpu_hash_combine(seed, key.type);
392
+ ggml_webgpu_hash_combine(seed, key.s_v);
393
+ ggml_webgpu_hash_combine(seed, key.kda);
394
+ return seed;
395
+ }
396
+ };
397
+
398
+ struct ggml_webgpu_ssm_conv_pipeline_key_hash {
399
+ size_t operator()(const ggml_webgpu_ssm_conv_pipeline_key & key) const {
400
+ size_t seed = 0;
401
+ ggml_webgpu_hash_combine(seed, key.type);
402
+ ggml_webgpu_hash_combine(seed, key.vectorized);
403
+ return seed;
404
+ }
405
+ };
406
+
169
407
  /** Scale **/
170
408
 
171
409
  struct ggml_webgpu_scale_pipeline_key {
@@ -182,18 +420,47 @@ struct ggml_webgpu_scale_pipeline_key_hash {
182
420
  }
183
421
  };
184
422
 
423
+ /** Upscale **/
424
+
425
+ struct ggml_webgpu_upscale_pipeline_key {
426
+ ggml_type input_type;
427
+ ggml_type output_type;
428
+ uint32_t base_mode;
429
+ bool antialias;
430
+
431
+ bool operator==(const ggml_webgpu_upscale_pipeline_key & other) const {
432
+ return input_type == other.input_type && output_type == other.output_type && base_mode == other.base_mode &&
433
+ antialias == other.antialias;
434
+ }
435
+ };
436
+
437
+ struct ggml_webgpu_upscale_pipeline_key_hash {
438
+ size_t operator()(const ggml_webgpu_upscale_pipeline_key & key) const {
439
+ size_t seed = 0;
440
+ ggml_webgpu_hash_combine(seed, key.input_type);
441
+ ggml_webgpu_hash_combine(seed, key.output_type);
442
+ ggml_webgpu_hash_combine(seed, key.base_mode);
443
+ ggml_webgpu_hash_combine(seed, key.antialias);
444
+ return seed;
445
+ }
446
+ };
447
+
185
448
  /** Concat **/
186
449
 
187
450
  struct ggml_webgpu_concat_pipeline_key {
188
- int type;
451
+ int type;
452
+ bool src_overlap;
189
453
 
190
- bool operator==(const ggml_webgpu_concat_pipeline_key & other) const { return type == other.type; }
454
+ bool operator==(const ggml_webgpu_concat_pipeline_key & other) const {
455
+ return type == other.type && src_overlap == other.src_overlap;
456
+ }
191
457
  };
192
458
 
193
459
  struct ggml_webgpu_concat_pipeline_key_hash {
194
460
  size_t operator()(const ggml_webgpu_concat_pipeline_key & key) const {
195
461
  size_t seed = 0;
196
462
  ggml_webgpu_hash_combine(seed, key.type);
463
+ ggml_webgpu_hash_combine(seed, key.src_overlap);
197
464
  return seed;
198
465
  }
199
466
  };
@@ -241,16 +508,34 @@ struct ggml_webgpu_binary_pipeline_key_hash {
241
508
  }
242
509
  };
243
510
 
511
+ /* Add_Id */
512
+
513
+ struct ggml_webgpu_add_id_pipeline_key {
514
+ bool inplace;
515
+
516
+ bool operator==(const ggml_webgpu_add_id_pipeline_key & other) const { return inplace == other.inplace; }
517
+ };
518
+
519
+ struct ggml_webgpu_add_id_pipeline_key_hash {
520
+ size_t operator()(const ggml_webgpu_add_id_pipeline_key & key) const {
521
+ size_t seed = 0;
522
+ ggml_webgpu_hash_combine(seed, key.inplace);
523
+ return seed;
524
+ }
525
+ };
526
+
244
527
  /** Unary **/
245
528
 
246
529
  struct ggml_webgpu_unary_pipeline_key {
247
- int type;
248
- int op;
249
- bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
250
- bool inplace;
530
+ int type;
531
+ int op;
532
+ bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
533
+ bool inplace;
534
+ ggml_tri_type ttype; // only used for GGML_OP_TRI
251
535
 
252
536
  bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
253
- return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
537
+ return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace &&
538
+ ttype == other.ttype;
254
539
  }
255
540
  };
256
541
 
@@ -261,58 +546,285 @@ struct ggml_webgpu_unary_pipeline_key_hash {
261
546
  ggml_webgpu_hash_combine(seed, key.op);
262
547
  ggml_webgpu_hash_combine(seed, key.is_unary);
263
548
  ggml_webgpu_hash_combine(seed, key.inplace);
549
+ ggml_webgpu_hash_combine(seed, key.ttype);
264
550
  return seed;
265
551
  }
266
552
  };
267
553
 
268
554
  /** FlashAttention */
269
555
 
270
- struct ggml_webgpu_flash_attn_pipeline_key {
271
- ggml_type kv_type;
556
+ struct ggml_webgpu_flash_attn_common_pipeline_key {
557
+ ggml_type q_type;
558
+ ggml_type k_type;
559
+ ggml_type v_type;
560
+ ggml_type dst_type;
272
561
  uint32_t head_dim_qk;
273
562
  uint32_t head_dim_v;
274
563
  bool kv_direct;
564
+ bool kv_overlap;
275
565
  bool has_mask;
276
566
  bool has_sinks;
277
567
  bool uses_logit_softcap;
278
568
 
569
+ bool operator==(const ggml_webgpu_flash_attn_common_pipeline_key & other) const {
570
+ return q_type == other.q_type && k_type == other.k_type && v_type == other.v_type &&
571
+ dst_type == other.dst_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
572
+ kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask &&
573
+ has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap;
574
+ }
575
+ };
576
+
577
+ inline void ggml_webgpu_flash_attn_hash_common_pipeline_key(size_t & seed,
578
+ const ggml_webgpu_flash_attn_common_pipeline_key & key) {
579
+ ggml_webgpu_hash_combine(seed, key.q_type);
580
+ ggml_webgpu_hash_combine(seed, key.k_type);
581
+ ggml_webgpu_hash_combine(seed, key.v_type);
582
+ ggml_webgpu_hash_combine(seed, key.dst_type);
583
+ ggml_webgpu_hash_combine(seed, key.head_dim_qk);
584
+ ggml_webgpu_hash_combine(seed, key.head_dim_v);
585
+ ggml_webgpu_hash_combine(seed, key.kv_direct);
586
+ ggml_webgpu_hash_combine(seed, key.kv_overlap);
587
+ ggml_webgpu_hash_combine(seed, key.has_mask);
588
+ ggml_webgpu_hash_combine(seed, key.has_sinks);
589
+ ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
590
+ }
591
+
592
+ struct ggml_webgpu_flash_attn_vec_pipeline_key {
593
+ ggml_webgpu_flash_attn_common_pipeline_key common;
594
+
595
+ bool operator==(const ggml_webgpu_flash_attn_vec_pipeline_key & other) const { return common == other.common; }
596
+ };
597
+
598
+ struct ggml_webgpu_flash_attn_vec_pipeline_key_hash {
599
+ size_t operator()(const ggml_webgpu_flash_attn_vec_pipeline_key & key) const {
600
+ size_t seed = 0;
601
+ ggml_webgpu_flash_attn_hash_common_pipeline_key(seed, key.common);
602
+ return seed;
603
+ }
604
+ };
605
+
606
+ struct ggml_webgpu_flash_attn_pipeline_key {
607
+ ggml_webgpu_flash_attn_common_pipeline_key common;
608
+ bool use_sg_matrix;
609
+
279
610
  bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
280
- return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
281
- kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
282
- uses_logit_softcap == other.uses_logit_softcap;
611
+ return common == other.common && use_sg_matrix == other.use_sg_matrix;
283
612
  }
284
613
  };
285
614
 
286
615
  struct ggml_webgpu_flash_attn_pipeline_key_hash {
287
616
  size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const {
288
617
  size_t seed = 0;
289
- ggml_webgpu_hash_combine(seed, key.kv_type);
290
- ggml_webgpu_hash_combine(seed, key.head_dim_qk);
618
+ ggml_webgpu_flash_attn_hash_common_pipeline_key(seed, key.common);
619
+ ggml_webgpu_hash_combine(seed, key.use_sg_matrix);
620
+ return seed;
621
+ }
622
+ };
623
+
624
// Launch decisions for the vec (single-query) flash-attention path.
struct ggml_webgpu_flash_attn_vec_decisions {
    uint32_t kv_tile{ 0 };
    uint32_t wg_size{ 0 };
};

// Launch decisions for the tiled flash-attention path.
struct ggml_webgpu_flash_attn_decisions {
    bool     use_sg_matrix{ false };
    uint32_t q_tile{ 0 };
    uint32_t kv_tile{ 0 };
    uint32_t wg_size{ 0 };
};
635
+
636
// Element multiple that a K/V offset must satisfy in the vec4 alignment check.
inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH{ 4 };
// Q tile size for the tiled flash-attention path (NOTE(review): consumers not visible here).
inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE{ 4 };
638
+
639
+ inline size_t ggml_webgpu_flash_attn_tensor_offset(const ggml_tensor * tensor) {
640
+ constexpr uintptr_t ptr_base_addr = 0x1000u;
641
+ const ggml_tensor * base = tensor->view_src != nullptr ? tensor->view_src : tensor;
642
+ return reinterpret_cast<uintptr_t>(base->data) - ptr_base_addr + tensor->view_offs;
643
+ }
644
+
645
+ inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, size_t storage_offset_alignment) {
646
+ const uint32_t offset_elems =
647
+ (uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (storage_offset_alignment - 1)) /
648
+ ggml_type_size(K->type));
649
+ return offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u;
650
+ }
651
+
652
+ inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K,
653
+ const ggml_tensor * V,
654
+ size_t storage_offset_alignment) {
655
+ return ggml_webgpu_flash_attn_float_vec4_aligned(K, storage_offset_alignment) &&
656
+ ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment);
657
+ }
658
+
659
+ inline bool ggml_webgpu_flash_attn_kv_direct(const ggml_tensor * Q,
660
+ const ggml_tensor * K,
661
+ const ggml_tensor * V,
662
+ uint32_t kv_direct_align) {
663
+ return K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && (Q->ne[0] % kv_direct_align == 0) &&
664
+ (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
665
+ }
666
+
667
+ inline ggml_webgpu_flash_attn_common_pipeline_key ggml_webgpu_flash_attn_make_common_pipeline_key(
668
+ const ggml_webgpu_shader_lib_context & context,
669
+ uint32_t kv_direct_align) {
670
+ ggml_webgpu_flash_attn_common_pipeline_key key = {};
671
+ key.q_type = context.src0->type;
672
+ key.k_type = context.src1->type;
673
+ key.v_type = context.src2->type;
674
+ key.dst_type = context.dst->type;
675
+ key.head_dim_qk = (uint32_t) context.src0->ne[0];
676
+ key.head_dim_v = (uint32_t) context.src2->ne[0];
677
+ key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, context.src2, kv_direct_align);
678
+ key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2);
679
+ key.has_mask = context.src3 != nullptr;
680
+ key.has_sinks = context.src4 != nullptr;
681
+ key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
682
+ return key;
683
+ }
684
+
685
+ inline std::vector<std::string> ggml_webgpu_flash_attn_common_defines(
686
+ const ggml_webgpu_flash_attn_common_pipeline_key & key,
687
+ std::string & variant,
688
+ uint32_t q_tile,
689
+ uint32_t kv_tile,
690
+ uint32_t wg_size) {
691
+ std::vector<std::string> defines;
692
+
693
+ switch (key.k_type) {
694
+ case GGML_TYPE_F32:
695
+ defines.push_back("K_F32");
696
+ break;
697
+ case GGML_TYPE_F16:
698
+ defines.push_back("K_F16");
699
+ break;
700
+ case GGML_TYPE_Q4_0:
701
+ defines.push_back("K_Q4_0");
702
+ break;
703
+ case GGML_TYPE_Q8_0:
704
+ defines.push_back("K_Q8_0");
705
+ break;
706
+ default:
707
+ GGML_ABORT("Unsupported K type for flash attention shader");
708
+ }
709
+ variant += std::string("_k") + ggml_type_name(key.k_type);
710
+
711
+ switch (key.v_type) {
712
+ case GGML_TYPE_F32:
713
+ defines.push_back("V_F32");
714
+ break;
715
+ case GGML_TYPE_F16:
716
+ defines.push_back("V_F16");
717
+ break;
718
+ case GGML_TYPE_Q4_0:
719
+ defines.push_back("V_Q4_0");
720
+ break;
721
+ case GGML_TYPE_Q8_0:
722
+ defines.push_back("V_Q8_0");
723
+ break;
724
+ default:
725
+ GGML_ABORT("Unsupported V type for flash attention shader");
726
+ }
727
+ variant += std::string("_v") + ggml_type_name(key.v_type);
728
+
729
+ switch (key.q_type) {
730
+ case GGML_TYPE_F32:
731
+ defines.push_back("Q_F32");
732
+ break;
733
+ case GGML_TYPE_F16:
734
+ defines.push_back("Q_F16");
735
+ break;
736
+ default:
737
+ GGML_ABORT("Unsupported Q type for flash attention shader");
738
+ }
739
+ variant += std::string("_q") + ggml_type_name(key.q_type);
740
+
741
+ switch (key.dst_type) {
742
+ case GGML_TYPE_F32:
743
+ defines.push_back("DST_F32");
744
+ break;
745
+ case GGML_TYPE_F16:
746
+ defines.push_back("DST_F16");
747
+ break;
748
+ default:
749
+ GGML_ABORT("Unsupported dst type for flash attention shader");
750
+ }
751
+ variant += std::string("_dst") + ggml_type_name(key.dst_type);
752
+
753
+ if (key.has_mask) {
754
+ defines.push_back("MASK");
755
+ variant += "_mask";
756
+ }
757
+ if (key.has_sinks) {
758
+ defines.push_back("SINKS");
759
+ variant += "_sinks";
760
+ }
761
+ if (key.uses_logit_softcap) {
762
+ defines.push_back("LOGIT_SOFTCAP");
763
+ variant += "_lgsc";
764
+ }
765
+ if (key.kv_direct) {
766
+ defines.push_back("KV_DIRECT");
767
+ variant += "_kvdirect";
768
+ }
769
+ if (key.kv_overlap) {
770
+ defines.push_back("KV_OVERLAP");
771
+ variant += "_kv_overlap";
772
+ }
773
+
774
+ defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
775
+ variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
776
+
777
+ defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
778
+ variant += std::string("_hsv") + std::to_string(key.head_dim_v);
779
+
780
+ defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
781
+ defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
782
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
783
+
784
+ if (ggml_is_quantized(key.k_type) || ggml_is_quantized(key.v_type)) {
785
+ defines.push_back("U32_DEQUANT_HELPERS");
786
+ }
787
+
788
+ return defines;
789
+ }
790
+
791
+ struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key {
792
+ uint32_t head_dim_v;
793
+ uint32_t wg_size;
794
+ ggml_type dst_type;
795
+ };
796
+
797
+ struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash {
798
+ size_t operator()(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & key) const {
799
+ size_t seed = 0;
291
800
  ggml_webgpu_hash_combine(seed, key.head_dim_v);
292
- ggml_webgpu_hash_combine(seed, key.kv_direct);
293
- ggml_webgpu_hash_combine(seed, key.has_mask);
294
- ggml_webgpu_hash_combine(seed, key.has_sinks);
295
- ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
801
+ ggml_webgpu_hash_combine(seed, key.wg_size);
802
+ ggml_webgpu_hash_combine(seed, key.dst_type);
296
803
  return seed;
297
804
  }
298
805
  };
299
806
 
300
- struct ggml_webgpu_flash_attn_shader_lib_context {
301
- ggml_webgpu_flash_attn_pipeline_key key;
302
- uint32_t sg_mat_m;
303
- uint32_t sg_mat_n;
304
- uint32_t sg_mat_k;
305
- size_t wg_mem_limit_bytes;
306
- uint32_t max_subgroup_size;
807
+ inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lhs,
808
+ const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & rhs) {
809
+ return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size && lhs.dst_type == rhs.dst_type;
810
+ }
811
+
812
+ struct ggml_webgpu_flash_attn_blk_pipeline_key {
813
+ uint32_t kv_tile;
814
+
815
+ bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const { return kv_tile == other.kv_tile; }
307
816
  };
308
817
 
309
- struct ggml_webgpu_flash_attn_shader_decisions {
310
- uint32_t q_tile = 0;
311
- uint32_t kv_tile = 0;
312
- uint32_t wg_size = 0;
818
+ struct ggml_webgpu_flash_attn_blk_pipeline_key_hash {
819
+ size_t operator()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const {
820
+ size_t seed = 0;
821
+ ggml_webgpu_hash_combine(seed, key.kv_tile);
822
+ return seed;
823
+ }
313
824
  };
314
825
 
315
- // This is exposed because it's necessary in supports_op
826
+ // Note: this will slightly overestimate memory usage for vec path
827
+ // since row_max and exp_sum shmem are not needed.
316
828
  inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
317
829
  uint32_t kv_tile,
318
830
  uint32_t head_dim_qk,
@@ -322,47 +834,82 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
322
834
  const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v);
323
835
  size_t f16_elems = 0;
324
836
  size_t f32_elems = 0;
325
- f16_elems += q_tile * head_dim_qk; // q_shmem
837
+
838
+ f32_elems += q_tile * head_dim_qk; // q_shmem
326
839
  if (!kv_direct) {
327
- f16_elems += kv_tile * max_head_dim; // kv_shmem
840
+ f32_elems += kv_tile * max_head_dim; // kv_shmem
328
841
  }
329
- f16_elems += q_tile * head_dim_v; // o_shmem
842
+ f32_elems += q_tile * head_dim_v; // o_shmem
330
843
  if (has_mask) {
331
- f16_elems += q_tile * kv_tile; // mask_shmem
844
+ f32_elems += q_tile * kv_tile; // mask_shmem
332
845
  }
333
- f16_elems += q_tile * kv_tile; // inter_shmem
846
+ f32_elems += q_tile * kv_tile; // inter_shmem
334
847
  f32_elems += q_tile; // row_max_shmem
335
848
  f32_elems += q_tile; // exp_sum_shmem
336
849
  return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
337
850
  }
338
851
 
339
- /** Matrix Multiplication **/
340
-
341
- struct ggml_webgpu_legacy_mul_mat_pipeline_key {
342
- ggml_type src0_type;
343
- ggml_type src1_type;
344
-
345
- bool operator==(const ggml_webgpu_legacy_mul_mat_pipeline_key & other) const {
346
- return src0_type == other.src0_type && src1_type == other.src1_type;
852
+ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(size_t limit_bytes,
853
+ uint32_t q_tile,
854
+ uint32_t kv_granularity,
855
+ uint32_t head_dim_qk,
856
+ uint32_t head_dim_v,
857
+ bool has_mask,
858
+ bool kv_direct) {
859
+ const size_t base_q_bytes =
860
+ ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 0, head_dim_qk, head_dim_v, has_mask, kv_direct);
861
+ if (limit_bytes <= base_q_bytes) {
862
+ return 0;
347
863
  }
348
- };
864
+ const size_t one_kv_bytes =
865
+ ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 1, head_dim_qk, head_dim_v, has_mask, kv_direct);
866
+ const size_t bytes_per_kv = one_kv_bytes - base_q_bytes;
867
+ if (bytes_per_kv == 0) {
868
+ return 0;
869
+ }
870
+ const size_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
871
+ return (uint32_t) ((max_kv_tile / kv_granularity) * kv_granularity);
872
+ }
349
873
 
350
- struct ggml_webgpu_legacy_mul_mat_pipeline_key_hash {
351
- size_t operator()(const ggml_webgpu_legacy_mul_mat_pipeline_key & key) const {
352
- size_t seed = 0;
353
- ggml_webgpu_hash_combine(seed, key.src0_type);
354
- ggml_webgpu_hash_combine(seed, key.src1_type);
355
- return seed;
874
+ inline uint32_t ggml_webgpu_flash_attn_get_vec_kv_tile(size_t wg_mem_limit_bytes,
875
+ uint32_t head_dim_qk,
876
+ uint32_t head_dim_v,
877
+ bool has_mask,
878
+ bool kv_direct) {
879
+ const uint32_t max_kv_tile =
880
+ ggml_webgpu_flash_attn_max_kv_tile(wg_mem_limit_bytes, 1u, 1u, head_dim_qk, head_dim_v, has_mask, kv_direct);
881
+ GGML_ASSERT(max_kv_tile > 0);
882
+
883
+ uint32_t kv_tile = std::min(GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE, max_kv_tile);
884
+ if (kv_direct) {
885
+ kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
886
+ while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
887
+ kv_tile -= 1u;
888
+ }
356
889
  }
357
- };
890
+
891
+ return kv_tile;
892
+ }
893
+
894
+ inline bool ggml_webgpu_flash_attn_can_use_subgroup_matrix_path(bool supports_subgroup_matrix,
895
+ uint32_t sg_mat_k,
896
+ uint32_t sg_mat_n,
897
+ const ggml_tensor * Q,
898
+ const ggml_tensor * V) {
899
+ return supports_subgroup_matrix && Q->ne[0] % sg_mat_k == 0 && V->ne[0] % sg_mat_n == 0;
900
+ }
901
+
902
+ /** Matrix Multiplication **/
358
903
 
359
904
  struct ggml_webgpu_mul_mat_vec_pipeline_key {
360
905
  ggml_type src0_type;
361
906
  ggml_type src1_type;
362
907
  int vectorized;
908
+ bool use_mmvq;
363
909
 
364
910
  bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {
365
- return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized;
911
+ return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
912
+ use_mmvq == other.use_mmvq;
366
913
  }
367
914
  };
368
915
 
@@ -372,17 +919,31 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
372
919
  ggml_webgpu_hash_combine(seed, key.src0_type);
373
920
  ggml_webgpu_hash_combine(seed, key.src1_type);
374
921
  ggml_webgpu_hash_combine(seed, key.vectorized);
922
+ ggml_webgpu_hash_combine(seed, key.use_mmvq);
375
923
  return seed;
376
924
  }
377
925
  };
378
926
 
// Launch decisions for the matrix-vector multiply shader. Zero defaults, as in
// the other *_shader_decisions structs, so no field is read uninitialized.
struct ggml_webgpu_mul_mat_vec_shader_decisions {
    uint32_t wg_size        = 0;
    uint32_t outputs_per_wg = 0;
    uint32_t vec_size       = 0;
};
385
932
 
933
+ struct ggml_webgpu_quantize_q8_pipeline_key {
934
+ ggml_type src0_type;
935
+
936
+ bool operator==(const ggml_webgpu_quantize_q8_pipeline_key & other) const { return src0_type == other.src0_type; }
937
+ };
938
+
939
+ struct ggml_webgpu_quantize_q8_pipeline_key_hash {
940
+ size_t operator()(const ggml_webgpu_quantize_q8_pipeline_key & key) const {
941
+ size_t seed = 0;
942
+ ggml_webgpu_hash_combine(seed, key.src0_type);
943
+ return seed;
944
+ }
945
+ };
946
+
386
947
  struct ggml_webgpu_mul_mat_pipeline_key {
387
948
  ggml_type src0_type;
388
949
  ggml_type src1_type;
@@ -426,8 +987,152 @@ struct ggml_webgpu_mul_mat_shader_decisions {
426
987
  uint32_t mul_mat_wg_size;
427
988
  };
428
989
 
429
- class ggml_webgpu_shader_lib {
430
- wgpu::Device device;
990
+ /** MUL_MAT_ID **/
991
+
992
+ struct ggml_webgpu_mul_mat_id_pipeline_key {
993
+ ggml_type src0_type;
994
+ ggml_type src1_type;
995
+ uint32_t n_experts;
996
+ int vectorized;
997
+
998
+ bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const {
999
+ return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts &&
1000
+ vectorized == other.vectorized;
1001
+ }
1002
+ };
1003
+
1004
+ struct ggml_webgpu_mul_mat_id_pipeline_key_hash {
1005
+ size_t operator()(const ggml_webgpu_mul_mat_id_pipeline_key & key) const {
1006
+ size_t seed = 0;
1007
+ ggml_webgpu_hash_combine(seed, key.src0_type);
1008
+ ggml_webgpu_hash_combine(seed, key.src1_type);
1009
+ ggml_webgpu_hash_combine(seed, key.n_experts);
1010
+ ggml_webgpu_hash_combine(seed, key.vectorized);
1011
+ return seed;
1012
+ }
1013
+ };
1014
+
1015
+ /** Cpy **/
1016
+
1017
+ struct ggml_webgpu_cpy_pipeline_key {
1018
+ ggml_type src_type;
1019
+ ggml_type dst_type;
1020
+
1021
+ bool operator==(const ggml_webgpu_cpy_pipeline_key & other) const {
1022
+ return src_type == other.src_type && dst_type == other.dst_type;
1023
+ }
1024
+ };
1025
+
1026
+ struct ggml_webgpu_cpy_pipeline_key_hash {
1027
+ size_t operator()(const ggml_webgpu_cpy_pipeline_key & key) const {
1028
+ size_t seed = 0;
1029
+ ggml_webgpu_hash_combine(seed, key.src_type);
1030
+ ggml_webgpu_hash_combine(seed, key.dst_type);
1031
+ return seed;
1032
+ }
1033
+ };
1034
+
1035
+ /** Glu **/
1036
+
1037
+ struct ggml_webgpu_glu_pipeline_key {
1038
+ ggml_glu_op glu_op;
1039
+ ggml_type type;
1040
+ bool split;
1041
+
1042
+ bool operator==(const ggml_webgpu_glu_pipeline_key & other) const {
1043
+ return glu_op == other.glu_op && type == other.type && split == other.split;
1044
+ }
1045
+ };
1046
+
1047
+ struct ggml_webgpu_glu_pipeline_key_hash {
1048
+ size_t operator()(const ggml_webgpu_glu_pipeline_key & key) const {
1049
+ size_t seed = 0;
1050
+ ggml_webgpu_hash_combine(seed, key.glu_op);
1051
+ ggml_webgpu_hash_combine(seed, key.type);
1052
+ ggml_webgpu_hash_combine(seed, key.split);
1053
+ return seed;
1054
+ }
1055
+ };
1056
+
1057
+ /** Rope **/
1058
+
1059
+ struct ggml_webgpu_rope_pipeline_key {
1060
+ ggml_type type;
1061
+ bool inplace;
1062
+ bool has_ff;
1063
+
1064
+ bool operator==(const ggml_webgpu_rope_pipeline_key & other) const {
1065
+ return type == other.type && inplace == other.inplace && has_ff == other.has_ff;
1066
+ }
1067
+ };
1068
+
1069
+ struct ggml_webgpu_rope_pipeline_key_hash {
1070
+ size_t operator()(const ggml_webgpu_rope_pipeline_key & key) const {
1071
+ size_t seed = 0;
1072
+ ggml_webgpu_hash_combine(seed, key.type);
1073
+ ggml_webgpu_hash_combine(seed, key.inplace);
1074
+ ggml_webgpu_hash_combine(seed, key.has_ff);
1075
+ return seed;
1076
+ }
1077
+ };
1078
+
1079
+ /** SoftMax **/
1080
+
1081
+ struct ggml_webgpu_soft_max_pipeline_key {
1082
+ ggml_type mask_type;
1083
+ bool has_mask;
1084
+ bool has_sink;
1085
+ bool inplace;
1086
+
1087
+ bool operator==(const ggml_webgpu_soft_max_pipeline_key & other) const {
1088
+ return mask_type == other.mask_type && has_mask == other.has_mask && has_sink == other.has_sink &&
1089
+ inplace == other.inplace;
1090
+ }
1091
+ };
1092
+
1093
+ struct ggml_webgpu_soft_max_pipeline_key_hash {
1094
+ size_t operator()(const ggml_webgpu_soft_max_pipeline_key & key) const {
1095
+ size_t seed = 0;
1096
+ ggml_webgpu_hash_combine(seed, key.mask_type);
1097
+ ggml_webgpu_hash_combine(seed, key.has_mask);
1098
+ ggml_webgpu_hash_combine(seed, key.has_sink);
1099
+ ggml_webgpu_hash_combine(seed, key.inplace);
1100
+ return seed;
1101
+ }
1102
+ };
1103
+
1104
+ /** MMVQ **/
1105
+
1106
+ inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0,
1107
+ const ggml_tensor * src1,
1108
+ bool supports_dot_product,
1109
+ const std::string & vendor) {
1110
+ if (src1->ne[1] == 1) {
1111
+ bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia";
1112
+ if (supports_dp4a && supports_dot_product) {
1113
+ switch (src1->type) {
1114
+ case GGML_TYPE_F32:
1115
+ switch (src0->type) {
1116
+ case GGML_TYPE_Q4_0:
1117
+ case GGML_TYPE_Q4_1:
1118
+ case GGML_TYPE_Q8_0:
1119
+ case GGML_TYPE_Q2_K:
1120
+ case GGML_TYPE_Q4_K:
1121
+ return src0->ne[0] % 4 == 0;
1122
+ default:
1123
+ break;
1124
+ }
1125
+ break;
1126
+ default:
1127
+ break;
1128
+ }
1129
+ }
1130
+ }
1131
+ return false;
1132
+ }
1133
+
1134
+ class ggml_webgpu_shader_lib {
1135
+ wgpu::Device device;
431
1136
  pre_wgsl::Preprocessor preprocessor;
432
1137
 
433
1138
  std::unordered_map<int, webgpu_pipeline> sum_rows_pipelines; // key is fixed, no variants yet
@@ -435,33 +1140,81 @@ class ggml_webgpu_shader_lib {
435
1140
  std::unordered_map<int, webgpu_pipeline> argsort_pipelines; // key is order
436
1141
  std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order
437
1142
  std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
1143
+ std::unordered_map<ggml_webgpu_row_norm_pipeline_key, webgpu_pipeline, ggml_webgpu_row_norm_pipeline_key_hash>
1144
+ row_norm_pipelines; // op/inplace
1145
+
438
1146
  std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
439
- get_rows_pipelines; // src_type, vectorized
1147
+ get_rows_pipelines; // src_type, vectorized
440
1148
  std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
441
- unary_pipelines; // type/op/inplace
1149
+ unary_pipelines; // type/op/inplace
442
1150
  std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>
443
- scale_pipelines; // inplace
1151
+ scale_pipelines; // inplace
1152
+ std::unordered_map<ggml_webgpu_solve_tri_pipeline_key, webgpu_pipeline, ggml_webgpu_solve_tri_pipeline_key_hash>
1153
+ solve_tri_pipelines; // type
1154
+ std::unordered_map<ggml_webgpu_ssm_conv_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_conv_pipeline_key_hash>
1155
+ ssm_conv_pipelines; // type/vectorized
1156
+ std::unordered_map<ggml_webgpu_ssm_scan_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_scan_pipeline_key_hash>
1157
+ ssm_scan_pipelines; // type/d_state
1158
+ std::unordered_map<ggml_webgpu_gated_delta_net_pipeline_key,
1159
+ webgpu_pipeline,
1160
+ ggml_webgpu_gated_delta_net_pipeline_key_hash>
1161
+ gated_delta_net_pipelines; // type/S_v/kda
444
1162
  std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash>
445
- pad_pipelines; // circular/non-circular
1163
+ pad_pipelines; // circular/non-circular
446
1164
  std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
447
- binary_pipelines; // type/op/inplace/overlap
1165
+ binary_pipelines; // type/op/inplace/overlap/src_overlap
1166
+ std::unordered_map<ggml_webgpu_add_id_pipeline_key, webgpu_pipeline, ggml_webgpu_add_id_pipeline_key_hash>
1167
+ add_id_pipelines; // inplace
448
1168
  std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
449
- concat_pipelines; // type
1169
+ concat_pipelines; // type
450
1170
  std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
451
- repeat_pipelines; // type
1171
+ repeat_pipelines; // type
1172
+ std::unordered_map<ggml_webgpu_flash_attn_vec_pipeline_key,
1173
+ webgpu_pipeline,
1174
+ ggml_webgpu_flash_attn_vec_pipeline_key_hash>
1175
+ flash_attn_vec_pipelines;
452
1176
  std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
453
1177
  flash_attn_pipelines;
454
- std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
1178
+ std::unordered_map<ggml_webgpu_flash_attn_vec_reduce_pipeline_key,
455
1179
  webgpu_pipeline,
456
- ggml_webgpu_legacy_mul_mat_pipeline_key_hash>
457
- mul_mat_legacy_pipelines; // legacy mul_mat (non-subgroup/non-regtile/non-vec)
1180
+ ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash>
1181
+ flash_attn_vec_reduce_pipelines;
1182
+ std::unordered_map<ggml_webgpu_flash_attn_blk_pipeline_key,
1183
+ webgpu_pipeline,
1184
+ ggml_webgpu_flash_attn_blk_pipeline_key_hash>
1185
+ flash_attn_blk_pipelines;
458
1186
  std::unordered_map<ggml_webgpu_mul_mat_vec_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_vec_pipeline_key_hash>
459
- mul_mat_vec_pipelines; // fast mat-vec (n==1)
1187
+ mul_mat_vec_pipelines; // fast mat-vec (n==1)
460
1188
  std::unordered_map<ggml_webgpu_mul_mat_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_pipeline_key_hash>
461
- mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup)
1189
+ mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup)
1190
+ std::unordered_map<ggml_webgpu_quantize_q8_pipeline_key, webgpu_pipeline, ggml_webgpu_quantize_q8_pipeline_key_hash>
1191
+ quantize_q8_pipelines;
1192
+ std::unordered_map<int, webgpu_pipeline> mul_mat_id_gather_pipelines; // key is fixed
1193
+ std::unordered_map<ggml_webgpu_mul_mat_id_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_id_pipeline_key_hash>
1194
+ mul_mat_id_pipelines; // src0_type/src1_type
1195
+ std::unordered_map<ggml_webgpu_mul_mat_id_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_id_pipeline_key_hash>
1196
+ mul_mat_id_vec_pipelines; // src0_type/src1_type
462
1197
 
463
1198
  std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
464
1199
  set_rows_pipelines;
1200
+ std::unordered_map<ggml_webgpu_set_pipeline_key, webgpu_pipeline, ggml_webgpu_set_pipeline_key_hash> set_pipelines;
1201
+ std::unordered_map<ggml_webgpu_cpy_pipeline_key, webgpu_pipeline, ggml_webgpu_cpy_pipeline_key_hash> cpy_pipelines;
1202
+ std::unordered_map<ggml_webgpu_glu_pipeline_key, webgpu_pipeline, ggml_webgpu_glu_pipeline_key_hash> glu_pipelines;
1203
+ std::unordered_map<ggml_webgpu_rope_pipeline_key, webgpu_pipeline, ggml_webgpu_rope_pipeline_key_hash>
1204
+ rope_pipelines;
1205
+ std::unordered_map<ggml_webgpu_soft_max_pipeline_key, webgpu_pipeline, ggml_webgpu_soft_max_pipeline_key_hash>
1206
+ soft_max_pipelines;
1207
+ std::unordered_map<ggml_webgpu_conv2d_pipeline_key, webgpu_pipeline, ggml_webgpu_conv2d_pipeline_key_hash>
1208
+ conv2d_pipelines;
1209
+ std::unordered_map<ggml_webgpu_im2col_pipeline_key, webgpu_pipeline, ggml_webgpu_im2col_pipeline_key_hash>
1210
+ im2col_pipelines;
1211
+
1212
+ std::unordered_map<ggml_webgpu_rms_norm_mul_pipeline_key,
1213
+ webgpu_pipeline,
1214
+ ggml_webgpu_rms_norm_mul_pipeline_key_hash>
1215
+ rms_norm_mul_pipelines;
1216
+ std::unordered_map<ggml_webgpu_upscale_pipeline_key, webgpu_pipeline, ggml_webgpu_upscale_pipeline_key_hash>
1217
+ upscale_pipelines;
465
1218
 
466
1219
  public:
467
1220
  ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
@@ -479,6 +1232,70 @@ class ggml_webgpu_shader_lib {
479
1232
  return sum_rows_pipelines[1];
480
1233
  }
481
1234
 
1235
+ webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) {
1236
+ ggml_webgpu_row_norm_pipeline_key key = {};
1237
+ key.op = context.dst->op;
1238
+ key.src_type = context.src0->type;
1239
+ key.dst_type = context.dst->type;
1240
+ key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
1241
+
1242
+ auto it = row_norm_pipelines.find(key);
1243
+ if (it != row_norm_pipelines.end()) {
1244
+ return it->second;
1245
+ }
1246
+ std::vector<std::string> defines;
1247
+ std::string variant;
1248
+
1249
+ switch (key.op) {
1250
+ case GGML_OP_RMS_NORM:
1251
+ defines.push_back("RMS_NORM");
1252
+ variant = "rms_norm";
1253
+ break;
1254
+ case GGML_OP_NORM:
1255
+ defines.push_back("NORM");
1256
+ variant = "norm";
1257
+ break;
1258
+ case GGML_OP_L2_NORM:
1259
+ defines.push_back("L2_NORM");
1260
+ variant = "l2_norm";
1261
+ break;
1262
+ default:
1263
+ GGML_ABORT("Unsupported op for row_norm shader");
1264
+ }
1265
+
1266
+ if (key.inplace) {
1267
+ defines.push_back("INPLACE");
1268
+ variant += "_inplace";
1269
+ }
1270
+
1271
+ if (key.src_type == GGML_TYPE_F32) {
1272
+ defines.push_back("SRC_F32");
1273
+ variant += "_src_f32";
1274
+ } else if (key.src_type == GGML_TYPE_F16) {
1275
+ defines.push_back("SRC_F16");
1276
+ variant += "_src_f16";
1277
+ }
1278
+
1279
+ if (key.dst_type == GGML_TYPE_F32) {
1280
+ defines.push_back("DST_F32");
1281
+ variant += "_dst_f32";
1282
+ } else if (key.dst_type == GGML_TYPE_F16) {
1283
+ defines.push_back("DST_F16");
1284
+ variant += "_dst_f16";
1285
+ }
1286
+
1287
+ const uint32_t row_norm_wg_size = 128u;
1288
+ uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size);
1289
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
1290
+ auto processed = preprocessor.preprocess(wgsl_row_norm, defines);
1291
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
1292
+ decisions->wg_size = wg_size;
1293
+ decisions->inplace = key.inplace;
1294
+ row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
1295
+ row_norm_pipelines[key].context = decisions;
1296
+ return row_norm_pipelines[key];
1297
+ }
1298
+
482
1299
  webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) {
483
1300
  bool vec4 = context.src0->ne[0] % 4 == 0;
484
1301
 
@@ -500,9 +1317,13 @@ class ggml_webgpu_shader_lib {
500
1317
  }
501
1318
 
502
1319
  webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
503
- ggml_webgpu_set_rows_pipeline_key key = { .dst_type = context.dst->type,
504
- .vec4 = context.src0->ne[0] % 4 == 0,
505
- .i64_idx = context.src1->type == GGML_TYPE_I64 };
1320
+ const bool quantized = ggml_is_quantized(context.dst->type);
1321
+ ggml_webgpu_set_rows_pipeline_key key = {};
1322
+ key.dst_type = context.dst->type;
1323
+ key.vec4 =
1324
+ (context.dst->type == GGML_TYPE_F32 || context.dst->type == GGML_TYPE_F16) && context.src0->ne[0] % 4 == 0;
1325
+ key.i64_idx = context.src1->type == GGML_TYPE_I64;
1326
+ key.pair_blocks = quantized && ((context.src0->ne[0] / ggml_blck_size(context.dst->type)) % 2 == 0);
506
1327
 
507
1328
  auto it = set_rows_pipelines.find(key);
508
1329
  if (it != set_rows_pipelines.end()) {
@@ -521,6 +1342,14 @@ class ggml_webgpu_shader_lib {
521
1342
  defines.push_back("DST_F16");
522
1343
  variant += "_dstf16";
523
1344
  break;
1345
+ case GGML_TYPE_Q8_0:
1346
+ defines.push_back("DST_Q8_0");
1347
+ variant += "_dstq8_0";
1348
+ break;
1349
+ case GGML_TYPE_Q4_0:
1350
+ defines.push_back("DST_Q4_0");
1351
+ variant += "_dstq4_0";
1352
+ break;
524
1353
  default:
525
1354
  GGML_ABORT("Unsupported dst type for set_rows shader");
526
1355
  }
@@ -533,19 +1362,68 @@ class ggml_webgpu_shader_lib {
533
1362
  defines.push_back("I64_IDX");
534
1363
  variant += "_i64idx";
535
1364
  }
1365
+ if (key.pair_blocks) {
1366
+ defines.push_back("PAIR_BLOCKS");
1367
+ variant += "_pair_blocks";
1368
+ }
536
1369
 
537
1370
  defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
538
1371
 
539
- auto processed = preprocessor.preprocess(wgsl_set_rows, defines);
540
- auto decisions = std::make_shared<ggml_webgpu_set_rows_shader_decisions>();
1372
+ const auto & shader_source = quantized ? wgsl_set_rows_quant : wgsl_set_rows;
1373
+ auto processed = preprocessor.preprocess(shader_source, defines);
1374
+ auto decisions = std::make_shared<ggml_webgpu_set_rows_shader_decisions>();
541
1375
  decisions->vec4 = key.vec4;
542
1376
  decisions->i64_idx = key.i64_idx;
1377
+ decisions->pair_blocks = key.pair_blocks;
543
1378
  decisions->wg_size = context.max_wg_size;
544
1379
  set_rows_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
545
1380
  set_rows_pipelines[key].context = decisions;
546
1381
  return set_rows_pipelines[key];
547
1382
  }
548
1383
 
1384
+ webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) {
1385
+ ggml_webgpu_set_pipeline_key key = {};
1386
+ key.type = context.dst->type;
1387
+ key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
1388
+
1389
+ auto it = set_pipelines.find(key);
1390
+ if (it != set_pipelines.end()) {
1391
+ return it->second;
1392
+ }
1393
+
1394
+ std::vector<std::string> defines;
1395
+ std::string variant = "set";
1396
+
1397
+ switch (key.type) {
1398
+ case GGML_TYPE_F32:
1399
+ defines.push_back("TYPE_F32");
1400
+ variant += "_f32";
1401
+ break;
1402
+ case GGML_TYPE_I32:
1403
+ defines.push_back("TYPE_I32");
1404
+ variant += "_i32";
1405
+ break;
1406
+ default:
1407
+ GGML_ABORT("Unsupported type for set shader");
1408
+ }
1409
+
1410
+ if (key.inplace) {
1411
+ defines.push_back("INPLACE");
1412
+ variant += "_inplace";
1413
+ }
1414
+
1415
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
1416
+
1417
+ auto processed = preprocessor.preprocess(wgsl_set, defines);
1418
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
1419
+ decisions->wg_size = context.max_wg_size;
1420
+ decisions->inplace = key.inplace;
1421
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1422
+ pipeline.context = decisions;
1423
+ set_pipelines[key] = pipeline;
1424
+ return set_pipelines[key];
1425
+ }
1426
+
549
1427
  webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) {
550
1428
  auto it = cumsum_pipelines.find(1);
551
1429
  if (it != cumsum_pipelines.end()) {
@@ -614,10 +1492,9 @@ class ggml_webgpu_shader_lib {
614
1492
 
615
1493
  webgpu_pipeline get_get_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
616
1494
  const bool vectorized = context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0;
617
- ggml_webgpu_get_rows_pipeline_key key = {
618
- .src_type = context.src0->type,
619
- .vectorized = (int) vectorized,
620
- };
1495
+ ggml_webgpu_get_rows_pipeline_key key = {};
1496
+ key.src_type = context.src0->type;
1497
+ key.vectorized = (int) vectorized;
621
1498
 
622
1499
  auto it = get_rows_pipelines.find(key);
623
1500
  if (it != get_rows_pipelines.end()) {
@@ -632,6 +1509,7 @@ class ggml_webgpu_shader_lib {
632
1509
 
633
1510
  switch (key.src_type) {
634
1511
  case GGML_TYPE_F32:
1512
+ defines.push_back("FLOAT_PARALLEL");
635
1513
  if (key.vectorized) {
636
1514
  defines.push_back("F32_VEC");
637
1515
  defines.push_back("SRC_TYPE=vec4<f32>");
@@ -646,6 +1524,7 @@ class ggml_webgpu_shader_lib {
646
1524
  variant += "_f32";
647
1525
  break;
648
1526
  case GGML_TYPE_F16:
1527
+ defines.push_back("FLOAT_PARALLEL");
649
1528
  defines.push_back("F16");
650
1529
  defines.push_back("SRC_TYPE=f16");
651
1530
  defines.push_back("DST_TYPE=f32");
@@ -653,6 +1532,7 @@ class ggml_webgpu_shader_lib {
653
1532
  variant += "_f16";
654
1533
  break;
655
1534
  case GGML_TYPE_I32:
1535
+ defines.push_back("FLOAT_PARALLEL");
656
1536
  defines.push_back("I32");
657
1537
  defines.push_back("SRC_TYPE=i32");
658
1538
  defines.push_back("DST_TYPE=i32");
@@ -664,21 +1544,50 @@ class ggml_webgpu_shader_lib {
664
1544
  std::string type_upper = type_str;
665
1545
  std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
666
1546
 
1547
+ switch (key.src_type) {
1548
+ case GGML_TYPE_Q1_0:
1549
+ case GGML_TYPE_Q4_0:
1550
+ case GGML_TYPE_Q5_0:
1551
+ case GGML_TYPE_Q8_0:
1552
+ case GGML_TYPE_Q3_K:
1553
+ case GGML_TYPE_Q6_K:
1554
+ case GGML_TYPE_IQ2_XXS:
1555
+ case GGML_TYPE_IQ2_XS:
1556
+ case GGML_TYPE_IQ2_S:
1557
+ case GGML_TYPE_IQ3_XXS:
1558
+ case GGML_TYPE_IQ3_S:
1559
+ case GGML_TYPE_IQ1_S:
1560
+ case GGML_TYPE_IQ4_NL:
1561
+ case GGML_TYPE_MXFP4:
1562
+ {
1563
+ // Quantized types using u32 buffers for portability.
1564
+ defines.push_back("SRC_TYPE=u32");
1565
+ defines.push_back("U32_DEQUANT_HELPERS");
1566
+ break;
1567
+ }
1568
+ default:
1569
+ {
1570
+ defines.push_back(std::string("SRC_TYPE=") + type_str);
1571
+ }
1572
+ }
1573
+
667
1574
  defines.push_back("BYTE_HELPERS");
668
1575
  defines.push_back(type_upper + "_T");
669
1576
  defines.push_back(type_upper);
670
1577
  defines.push_back(type_upper + "_SCALE_MIN");
671
1578
  defines.push_back(type_upper + "_TABLES");
672
1579
  defines.push_back(type_upper + "_GRID");
1580
+ defines.push_back(type_upper + "_LUT");
673
1581
 
674
1582
  variant += "_";
675
1583
  variant += type_str;
676
1584
 
677
- defines.push_back(std::string("SRC_TYPE=") + type_str);
678
1585
  defines.push_back("DST_TYPE=f32");
679
1586
 
680
- if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
681
- key.src_type == GGML_TYPE_IQ4_NL) {
1587
+ if (key.src_type == GGML_TYPE_Q1_0) {
1588
+ defines.push_back("BLOCK_SIZE=128u");
1589
+ } else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
1590
+ key.src_type == GGML_TYPE_IQ4_NL || key.src_type == GGML_TYPE_MXFP4) {
682
1591
  defines.push_back("BLOCK_SIZE=32u");
683
1592
  } else if (key.src_type >= GGML_TYPE_Q2_K) {
684
1593
  defines.push_back("BLOCK_SIZE=256u");
@@ -705,7 +1614,8 @@ class ggml_webgpu_shader_lib {
705
1614
  }
706
1615
 
707
1616
  webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) {
708
- ggml_webgpu_scale_pipeline_key key = { .inplace = context.inplace };
1617
+ ggml_webgpu_scale_pipeline_key key = {};
1618
+ key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
709
1619
 
710
1620
  auto it = scale_pipelines.find(key);
711
1621
  if (it != scale_pipelines.end()) {
@@ -725,14 +1635,189 @@ class ggml_webgpu_shader_lib {
725
1635
  auto processed = preprocessor.preprocess(wgsl_scale, defines);
726
1636
  auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
727
1637
  decisions->wg_size = context.max_wg_size;
1638
+ decisions->inplace = key.inplace;
728
1639
  webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
729
1640
  pipeline.context = decisions;
730
1641
  scale_pipelines[key] = pipeline;
731
1642
  return scale_pipelines[key];
732
1643
  }
733
1644
 
1645
+ webgpu_pipeline get_solve_tri_pipeline(const ggml_webgpu_shader_lib_context & context) {
1646
+ ggml_webgpu_solve_tri_pipeline_key key = {};
1647
+ key.type = context.dst->type;
1648
+ key.n = (int) context.src0->ne[0];
1649
+ key.k = (int) context.src1->ne[0];
1650
+
1651
+ auto it = solve_tri_pipelines.find(key);
1652
+ if (it != solve_tri_pipelines.end()) {
1653
+ return it->second;
1654
+ }
1655
+
1656
+ std::vector<std::string> defines;
1657
+ std::string variant = "solve_tri";
1658
+
1659
+ switch (key.type) {
1660
+ case GGML_TYPE_F32:
1661
+ variant += "_f32";
1662
+ break;
1663
+ default:
1664
+ GGML_ABORT("Unsupported type for solve_tri shader");
1665
+ }
1666
+
1667
+ const uint32_t wg_size = std::min((uint32_t) key.n, context.max_wg_size);
1668
+ const uint32_t k_tile = wg_size;
1669
+ const uint32_t bytes_per_row = ((uint32_t) key.n + wg_size) * GGML_WEBGPU_F32_SIZE_BYTES;
1670
+ const uint32_t batch_n = (uint32_t) (context.wg_mem_limit_bytes / bytes_per_row);
1671
+
1672
+ defines.push_back(std::string("N=") + std::to_string(key.n));
1673
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
1674
+ defines.push_back(std::string("K_TILE=") + std::to_string(k_tile));
1675
+ defines.push_back(std::string("BATCH_N=") + std::to_string(batch_n));
1676
+
1677
+ auto processed = preprocessor.preprocess(wgsl_solve_tri, defines);
1678
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
1679
+ decisions->wg_size = wg_size;
1680
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1681
+ pipeline.context = decisions;
1682
+ solve_tri_pipelines[key] = pipeline;
1683
+ return solve_tri_pipelines[key];
1684
+ }
1685
+
1686
+ webgpu_pipeline get_ssm_conv_pipeline(const ggml_webgpu_shader_lib_context & context) {
1687
+ ggml_webgpu_ssm_conv_pipeline_key key = {};
1688
+ key.type = context.dst->type;
1689
+ key.vectorized = context.src1->ne[0] == 4;
1690
+
1691
+ auto it = ssm_conv_pipelines.find(key);
1692
+ if (it != ssm_conv_pipelines.end()) {
1693
+ return it->second;
1694
+ }
1695
+
1696
+ std::vector<std::string> defines;
1697
+ std::string variant = "ssm_conv";
1698
+
1699
+ switch (key.type) {
1700
+ case GGML_TYPE_F32:
1701
+ variant += "_f32";
1702
+ break;
1703
+ default:
1704
+ GGML_ABORT("Unsupported type for ssm_conv shader");
1705
+ }
1706
+
1707
+ if (key.vectorized) {
1708
+ defines.push_back("VECTORIZED");
1709
+ variant += "_vec4";
1710
+ }
1711
+
1712
+ constexpr uint32_t block_size = 32u;
1713
+ constexpr uint32_t tokens_per_wg = 8u;
1714
+
1715
+ defines.push_back("BLOCK_SIZE=" + std::to_string(block_size) + "u");
1716
+ defines.push_back("TOKENS_PER_WG=" + std::to_string(tokens_per_wg) + "u");
1717
+
1718
+ auto processed = preprocessor.preprocess(wgsl_ssm_conv, defines);
1719
+ auto decisions = std::make_shared<ggml_webgpu_ssm_conv_shader_decisions>();
1720
+ decisions->block_size = block_size;
1721
+ decisions->tokens_per_wg = tokens_per_wg;
1722
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1723
+ pipeline.context = decisions;
1724
+ ssm_conv_pipelines[key] = pipeline;
1725
+ return ssm_conv_pipelines[key];
1726
+ }
1727
+
1728
+ webgpu_pipeline get_ssm_scan_pipeline(const ggml_webgpu_shader_lib_context & context) {
1729
+ ggml_webgpu_ssm_scan_pipeline_key key = {};
1730
+ key.type = context.dst->type;
1731
+ key.d_state = (int) context.src0->ne[0];
1732
+ key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) &&
1733
+ ggml_webgpu_tensor_overlap(context.src1, context.src5);
1734
+
1735
+ auto it = ssm_scan_pipelines.find(key);
1736
+ if (it != ssm_scan_pipelines.end()) {
1737
+ return it->second;
1738
+ }
1739
+
1740
+ std::vector<std::string> defines;
1741
+ std::string variant = "ssm_scan";
1742
+
1743
+ switch (key.type) {
1744
+ case GGML_TYPE_F32:
1745
+ variant += "_f32";
1746
+ break;
1747
+ default:
1748
+ GGML_ABORT("Unsupported type for ssm_scan shader");
1749
+ }
1750
+
1751
+ const uint32_t wg_size = (uint32_t) key.d_state;
1752
+
1753
+ constexpr uint32_t tokens_per_tile = 4u;
1754
+
1755
+ defines.push_back("WG_SIZE=" + std::to_string(wg_size) + "u");
1756
+ defines.push_back("TOKENS_PER_TILE=" + std::to_string(tokens_per_tile) + "u");
1757
+
1758
+ if (context.supports_subgroups) {
1759
+ defines.push_back("USE_SUBGROUP_REDUCTION");
1760
+ variant += "_sg_reduce";
1761
+ } else {
1762
+ variant += "_wg_reduce";
1763
+ }
1764
+
1765
+ if (key.xbc_overlap) {
1766
+ defines.push_back("XBC_OVERLAP");
1767
+ }
1768
+
1769
+ variant += "_d" + std::to_string(key.d_state);
1770
+
1771
+ auto processed = preprocessor.preprocess(wgsl_ssm_scan, defines);
1772
+ auto decisions = std::make_shared<ggml_webgpu_ssm_scan_shader_decisions>();
1773
+ decisions->wg_size = wg_size;
1774
+ decisions->tokens_per_tile = tokens_per_tile;
1775
+ decisions->xbc_overlap = key.xbc_overlap;
1776
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1777
+ pipeline.context = decisions;
1778
+ ssm_scan_pipelines[key] = pipeline;
1779
+ return ssm_scan_pipelines[key];
1780
+ }
1781
+
1782
+ webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) {
1783
+ ggml_webgpu_gated_delta_net_pipeline_key key = {};
1784
+ key.type = context.dst->type;
1785
+ key.s_v = (int) context.src2->ne[0];
1786
+ key.kda = context.src3->ne[0] == context.src2->ne[0];
1787
+
1788
+ auto it = gated_delta_net_pipelines.find(key);
1789
+ if (it != gated_delta_net_pipelines.end()) {
1790
+ return it->second;
1791
+ }
1792
+
1793
+ std::vector<std::string> defines;
1794
+ std::string variant = "gated_delta_net";
1795
+
1796
+ switch (key.type) {
1797
+ case GGML_TYPE_F32:
1798
+ variant += "_f32";
1799
+ break;
1800
+ default:
1801
+ GGML_ABORT("Unsupported type for gated_delta_net shader");
1802
+ }
1803
+
1804
+ if (key.kda) {
1805
+ defines.push_back("KDA");
1806
+ variant += "_kda";
1807
+ }
1808
+
1809
+ defines.push_back("S_V=" + std::to_string(key.s_v) + "u");
1810
+ defines.push_back("WG_SIZE=" + std::to_string(key.s_v) + "u");
1811
+
1812
+ auto processed = preprocessor.preprocess(wgsl_gated_delta_net, defines);
1813
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1814
+ gated_delta_net_pipelines[key] = pipeline;
1815
+ return gated_delta_net_pipelines[key];
1816
+ }
1817
+
734
1818
  webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) {
735
- ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 };
1819
+ ggml_webgpu_pad_pipeline_key key = {};
1820
+ key.circular = ggml_get_op_params_i32(context.dst, 8) != 0;
736
1821
 
737
1822
  auto it = pad_pipelines.find(key);
738
1823
  if (it != pad_pipelines.end()) {
@@ -758,16 +1843,54 @@ class ggml_webgpu_shader_lib {
758
1843
  return pad_pipelines[key];
759
1844
  }
760
1845
 
1846
+ webgpu_pipeline get_quantize_q8_pipeline(const ggml_webgpu_shader_lib_context & context) {
1847
+ ggml_webgpu_quantize_q8_pipeline_key key = {};
1848
+ key.src0_type = context.src0->type;
1849
+
1850
+ auto it = quantize_q8_pipelines.find(key);
1851
+ if (it != quantize_q8_pipelines.end()) {
1852
+ return it->second;
1853
+ }
1854
+ const char * shader_src = wgsl_quantize_q8;
1855
+ std::vector<std::string> defines;
1856
+ std::string variant = "quantize_q8";
1857
+
1858
+ uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
1859
+
1860
+ defines.push_back("SRC1_INNER_TYPE=f32");
1861
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
1862
+
1863
+ const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
1864
+ std::string src0_name = src0_traits->type_name;
1865
+ std::string type_upper = src0_name;
1866
+ variant += "_" + src0_name;
1867
+ std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
1868
+
1869
+ defines.push_back("MUL_ACC_" + type_upper);
1870
+ defines.push_back("Q8_1_T");
1871
+
1872
+ defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION");
1873
+ variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce";
1874
+
1875
+ auto processed = preprocessor.preprocess(shader_src, defines);
1876
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
1877
+ decisions->wg_size = wg_size;
1878
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1879
+ pipeline.context = decisions;
1880
+ quantize_q8_pipelines[key] = pipeline;
1881
+ return quantize_q8_pipelines[key];
1882
+ }
1883
+
761
1884
  webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
762
- ggml_webgpu_mul_mat_vec_pipeline_key key = {
763
- .src0_type = context.src0->type,
764
- .src1_type = context.src1->type,
765
- // Quantized mat-vec path currently runs scalar; only allow vectorization when both inputs are float
766
- .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
767
- (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
768
- 1 :
769
- 0,
770
- };
1885
+ ggml_webgpu_mul_mat_vec_pipeline_key key = {};
1886
+ key.src0_type = context.src0->type;
1887
+ key.src1_type = context.src1->type;
1888
+ key.vectorized = (context.src0->ne[0] % 4 == 0 &&
1889
+ (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1890
+ 1 :
1891
+ 0;
1892
+ key.use_mmvq =
1893
+ ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor);
771
1894
 
772
1895
  auto it = mul_mat_vec_pipelines.find(key);
773
1896
  if (it != mul_mat_vec_pipelines.end()) {
@@ -775,7 +1898,8 @@ class ggml_webgpu_shader_lib {
775
1898
  }
776
1899
 
777
1900
  std::vector<std::string> defines;
778
- std::string variant = "mul_mat_vec";
1901
+ std::string variant = "mul_mat_vec";
1902
+ const char * shader_src = wgsl_mul_mat_vec;
779
1903
 
780
1904
  // src0 type (matrix row)
781
1905
  switch (context.src0->type) {
@@ -800,9 +1924,42 @@ class ggml_webgpu_shader_lib {
800
1924
 
801
1925
  defines.push_back("BYTE_HELPERS");
802
1926
  defines.push_back("MUL_ACC_" + type_upper);
803
-
804
- // For fast path we always dequantize from f16 inside the shader
805
- defines.push_back("SRC0_INNER_TYPE=f16");
1927
+ defines.push_back("U32_DEQUANT_HELPERS");
1928
+ defines.push_back("SRC0_INNER_TYPE=u32");
1929
+ switch (context.src0->type) {
1930
+ case GGML_TYPE_Q8_0:
1931
+ case GGML_TYPE_Q4_0:
1932
+ case GGML_TYPE_Q4_1:
1933
+ if (key.use_mmvq) {
1934
+ defines.push_back("LEGACY_QUANTS");
1935
+ }
1936
+ break;
1937
+ case GGML_TYPE_Q2_K:
1938
+ case GGML_TYPE_Q4_K:
1939
+ if (key.use_mmvq) {
1940
+ defines.push_back("K_QUANTS");
1941
+ }
1942
+ break;
1943
+ case GGML_TYPE_IQ1_S:
1944
+ case GGML_TYPE_IQ1_M:
1945
+ case GGML_TYPE_IQ2_S:
1946
+ case GGML_TYPE_IQ3_S:
1947
+ case GGML_TYPE_IQ4_NL:
1948
+ case GGML_TYPE_IQ4_XS:
1949
+ defines.push_back(type_upper + "_GRID");
1950
+ break;
1951
+ case GGML_TYPE_IQ2_XXS:
1952
+ case GGML_TYPE_IQ2_XS:
1953
+ case GGML_TYPE_IQ3_XXS:
1954
+ defines.push_back(type_upper + "_GRID");
1955
+ defines.push_back(type_upper + "_TABLES");
1956
+ break;
1957
+ case GGML_TYPE_MXFP4:
1958
+ defines.push_back(type_upper + "_LUT");
1959
+ break;
1960
+ default:
1961
+ break;
1962
+ }
806
1963
  break;
807
1964
  }
808
1965
  }
@@ -825,25 +1982,32 @@ class ggml_webgpu_shader_lib {
825
1982
  defines.push_back(key.vectorized ? "VEC" : "SCALAR");
826
1983
 
827
1984
  uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
828
- uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K;
829
1985
  uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
830
1986
 
831
- if (key.src0_type >= GGML_TYPE_Q2_K) {
832
- tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K;
1987
+ if (key.src0_type == GGML_TYPE_Q1_0) {
1988
+ outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
1989
+ } else if (key.src0_type >= GGML_TYPE_Q2_K) {
833
1990
  outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
834
1991
  } else if (key.src0_type >= GGML_TYPE_Q4_0) {
835
- tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K;
836
1992
  outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
837
1993
  }
838
1994
 
1995
+ if (key.use_mmvq) {
1996
+ defines.push_back("MMVQ");
1997
+ defines.push_back("Q8_1_T");
1998
+ }
1999
+
839
2000
  defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
840
- defines.push_back(std::string("TILE_K=") + std::to_string(tile_k));
841
2001
  defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
2002
+ defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION");
2003
+ variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce";
2004
+ if (key.vectorized) {
2005
+ variant += "_vectorized";
2006
+ }
842
2007
 
843
- auto processed = preprocessor.preprocess(wgsl_mul_mat_vec, defines);
2008
+ auto processed = preprocessor.preprocess(shader_src, defines);
844
2009
  auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
845
2010
  decisions->wg_size = wg_size;
846
- decisions->tile_k = tile_k;
847
2011
  decisions->outputs_per_wg = outputs_per_wg;
848
2012
  decisions->vec_size = key.vectorized ? 4 : 1;
849
2013
 
@@ -854,15 +2018,14 @@ class ggml_webgpu_shader_lib {
854
2018
  }
855
2019
 
856
2020
  webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) {
857
- ggml_webgpu_mul_mat_pipeline_key key = {
858
- .src0_type = context.src0->type,
859
- .src1_type = context.src1->type,
860
- .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 &&
861
- (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
862
- 1 :
863
- 0,
864
- .use_subgroup_matrix = context.supports_subgroup_matrix
865
- };
2021
+ ggml_webgpu_mul_mat_pipeline_key key = {};
2022
+ key.src0_type = context.src0->type;
2023
+ key.src1_type = context.src1->type;
2024
+ key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
2025
+ (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
2026
+ 1 :
2027
+ 0;
2028
+ key.use_subgroup_matrix = context.supports_subgroup_matrix;
866
2029
 
867
2030
  auto it = mul_mat_fast_pipelines.find(key);
868
2031
  if (it != mul_mat_fast_pipelines.end()) {
@@ -915,9 +2078,30 @@ class ggml_webgpu_shader_lib {
915
2078
  defines.push_back("MUL_ACC_" + type_upper);
916
2079
  defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
917
2080
  defines.push_back("INIT_SRC1_SHMEM_FLOAT");
918
-
919
- // Use f16 inside the shader for quantized types
920
- defines.push_back("SRC0_INNER_TYPE=f16");
2081
+ defines.push_back("U32_DEQUANT_HELPERS");
2082
+ defines.push_back("SRC0_INNER_TYPE=u32");
2083
+
2084
+ switch (context.src0->type) {
2085
+ case GGML_TYPE_IQ1_S:
2086
+ case GGML_TYPE_IQ1_M:
2087
+ case GGML_TYPE_IQ4_NL:
2088
+ case GGML_TYPE_IQ4_XS:
2089
+ defines.push_back(type_upper + "_GRID");
2090
+ break;
2091
+ case GGML_TYPE_IQ2_XXS:
2092
+ case GGML_TYPE_IQ2_XS:
2093
+ case GGML_TYPE_IQ2_S:
2094
+ case GGML_TYPE_IQ3_XXS:
2095
+ case GGML_TYPE_IQ3_S:
2096
+ defines.push_back(type_upper + "_GRID");
2097
+ defines.push_back(type_upper + "_TABLES");
2098
+ break;
2099
+ case GGML_TYPE_MXFP4:
2100
+ defines.push_back(type_upper + "_LUT");
2101
+ break;
2102
+ default:
2103
+ break;
2104
+ }
921
2105
 
922
2106
  variant += std::string("_") + src0_name;
923
2107
  break;
@@ -927,13 +2111,22 @@ class ggml_webgpu_shader_lib {
927
2111
  // VEC/SCALAR controls
928
2112
  defines.push_back(key.vectorized ? "VEC" : "SCALAR");
929
2113
 
2114
+ const bool is_quant = ggml_is_quantized(context.src0->type);
2115
+
2116
+ uint32_t tile_k;
2117
+ if (key.use_subgroup_matrix) {
2118
+ tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT : WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT;
2119
+ } else {
2120
+ tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
2121
+ }
2122
+
930
2123
  // Tiles
931
2124
  defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
932
2125
  defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
933
- defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u");
934
2126
 
935
2127
  // Subgroup matrix specifics
936
2128
  if (key.use_subgroup_matrix) {
2129
+ defines.push_back("TILE_K=" + std::to_string(tile_k) + "u");
937
2130
  defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u");
938
2131
  defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u");
939
2132
  defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u");
@@ -953,12 +2146,13 @@ class ggml_webgpu_shader_lib {
953
2146
  if (!key.use_subgroup_matrix) {
954
2147
  defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u");
955
2148
  defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u");
2149
+ defines.push_back("TILE_K=" + std::to_string(tile_k) + "u");
956
2150
  }
957
2151
 
958
2152
  auto processed = preprocessor.preprocess(shader_src, defines);
959
2153
 
960
2154
  auto decisions = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
961
- decisions->tile_k = WEBGPU_MUL_MAT_TILE_K;
2155
+ decisions->tile_k = tile_k;
962
2156
  decisions->tile_m = WEBGPU_MUL_MAT_TILE_M;
963
2157
  decisions->tile_n = WEBGPU_MUL_MAT_TILE_N;
964
2158
  decisions->use_subgroup_matrix = key.use_subgroup_matrix;
@@ -981,84 +2175,276 @@ class ggml_webgpu_shader_lib {
981
2175
  return mul_mat_fast_pipelines[key];
982
2176
  }
983
2177
 
984
- webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) {
985
- ggml_webgpu_legacy_mul_mat_pipeline_key key = { .src0_type = context.src0->type,
986
- .src1_type = context.src1->type };
2178
+ webgpu_pipeline get_mul_mat_id_gather_pipeline(const ggml_webgpu_shader_lib_context & context) {
2179
+ auto it = mul_mat_id_gather_pipelines.find(1);
2180
+ if (it != mul_mat_id_gather_pipelines.end()) {
2181
+ return it->second;
2182
+ }
2183
+ std::vector<std::string> defines;
2184
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
987
2185
 
988
- auto it = mul_mat_legacy_pipelines.find(key);
989
- if (it != mul_mat_legacy_pipelines.end()) {
2186
+ auto processed = preprocessor.preprocess(wgsl_mul_mat_id_gather, defines);
2187
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
2188
+ decisions->wg_size = context.max_wg_size;
2189
+
2190
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, "mul_mat_id_gather");
2191
+ pipeline.context = decisions;
2192
+ mul_mat_id_gather_pipelines[1] = pipeline;
2193
+ return pipeline;
2194
+ }
2195
+
2196
+ webgpu_pipeline get_mul_mat_id_pipeline(const ggml_webgpu_shader_lib_context & context) {
2197
+ ggml_webgpu_mul_mat_id_pipeline_key key = {};
2198
+ key.src0_type = context.src0->type;
2199
+ key.src1_type = context.src1->type;
2200
+ key.n_experts = context.src0->ne[2];
2201
+ key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 &&
2202
+ (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
2203
+ 1 :
2204
+ 0;
2205
+
2206
+ auto it = mul_mat_id_pipelines.find(key);
2207
+ if (it != mul_mat_id_pipelines.end()) {
990
2208
  return it->second;
991
2209
  }
992
2210
 
993
2211
  std::vector<std::string> defines;
994
- std::string variant = "mul_mat";
2212
+ std::string variant = "mul_mat_id";
2213
+ defines.push_back("MUL_MAT_ID");
995
2214
 
2215
+ // src1 type
996
2216
  switch (context.src1->type) {
997
2217
  case GGML_TYPE_F32:
998
- defines.push_back("SRC1_TYPE=f32");
999
- variant += "_f32";
2218
+ defines.push_back("SRC1_INNER_TYPE=f32");
1000
2219
  break;
1001
2220
  case GGML_TYPE_F16:
1002
- defines.push_back("SRC1_TYPE=f16");
1003
- variant += "_f16";
2221
+ defines.push_back("SRC1_INNER_TYPE=f16");
1004
2222
  break;
1005
2223
  default:
1006
- GGML_ABORT("Unsupported src1 type for mul_mat legacy shader");
2224
+ GGML_ABORT("Unsupported src1 type for mul_mat fast shader");
1007
2225
  }
1008
2226
 
2227
+ // src0 type
1009
2228
  const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
1010
2229
  const char * src0_name = src0_traits->type_name;
1011
2230
 
1012
2231
  switch (context.src0->type) {
1013
2232
  case GGML_TYPE_F32:
1014
- defines.push_back("SRC0_TYPE=f32");
1015
- defines.push_back("FLOAT");
2233
+ defines.push_back("SRC0_INNER_TYPE=f32");
2234
+ defines.push_back("INIT_SRC0_SHMEM_FLOAT");
2235
+ defines.push_back("INIT_SRC1_SHMEM_FLOAT");
1016
2236
  variant += "_f32";
1017
2237
  break;
1018
2238
  case GGML_TYPE_F16:
1019
- defines.push_back("SRC0_TYPE=f16");
1020
- defines.push_back("FLOAT");
2239
+ defines.push_back("SRC0_INNER_TYPE=f16");
2240
+ defines.push_back("INIT_SRC0_SHMEM_FLOAT");
2241
+ defines.push_back("INIT_SRC1_SHMEM_FLOAT");
1021
2242
  variant += "_f16";
1022
2243
  break;
1023
2244
  default:
1024
2245
  {
1025
- // quantized types
1026
2246
  std::string type_upper = src0_name;
1027
2247
  std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
1028
2248
 
1029
- defines.push_back(std::string("SRC0_TYPE=") + src0_name);
1030
2249
  defines.push_back("BYTE_HELPERS");
1031
- defines.push_back(type_upper + "_T");
1032
- defines.push_back(type_upper);
1033
- defines.push_back(type_upper + "_SCALE_MIN");
1034
- defines.push_back(type_upper + "_TABLES");
1035
- defines.push_back(type_upper + "_GRID");
2250
+ defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
2251
+ defines.push_back("INIT_SRC1_SHMEM_FLOAT");
2252
+ defines.push_back("U32_DEQUANT_HELPERS");
2253
+ defines.push_back("SRC0_INNER_TYPE=u32");
2254
+
2255
+ switch (context.src0->type) {
2256
+ case GGML_TYPE_IQ1_S:
2257
+ case GGML_TYPE_IQ1_M:
2258
+ case GGML_TYPE_IQ4_NL:
2259
+ case GGML_TYPE_IQ4_XS:
2260
+ defines.push_back(type_upper + "_GRID");
2261
+ break;
2262
+ case GGML_TYPE_IQ2_XXS:
2263
+ case GGML_TYPE_IQ2_XS:
2264
+ case GGML_TYPE_IQ2_S:
2265
+ case GGML_TYPE_IQ3_XXS:
2266
+ case GGML_TYPE_IQ3_S:
2267
+ defines.push_back(type_upper + "_GRID");
2268
+ defines.push_back(type_upper + "_TABLES");
2269
+ break;
2270
+ case GGML_TYPE_MXFP4:
2271
+ defines.push_back(type_upper + "_LUT");
2272
+ break;
2273
+ default:
2274
+ break;
2275
+ }
1036
2276
 
1037
2277
  variant += std::string("_") + src0_name;
1038
2278
  break;
1039
2279
  }
1040
2280
  }
1041
2281
 
1042
- auto processed = preprocessor.preprocess(wgsl_mul_mat, defines);
2282
+ // VEC/SCALAR controls
2283
+ defines.push_back(key.vectorized ? "VEC" : "SCALAR");
1043
2284
 
1044
- auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
1045
- decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE;
2285
+ // mul_mat_id is register-tile only.
2286
+ const uint32_t tile_k =
2287
+ ggml_is_quantized(context.src0->type) ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
2288
+
2289
+ // Tiles
2290
+ defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
2291
+ defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
2292
+ defines.push_back("TILE_K=" + std::to_string(tile_k) + "u");
2293
+
2294
+ defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u");
2295
+ defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u");
2296
+
2297
+ // variant suffix for src1 type
2298
+ variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16");
2299
+ if (key.vectorized) {
2300
+ variant += "_vectorized";
2301
+ }
2302
+
2303
+ auto processed = preprocessor.preprocess(wgsl_mul_mat_id, defines);
2304
+
2305
+ auto decisions = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
2306
+ decisions->tile_k = tile_k;
2307
+ decisions->tile_m = WEBGPU_MUL_MAT_TILE_M;
2308
+ decisions->tile_n = WEBGPU_MUL_MAT_TILE_N;
2309
+ decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M;
2310
+ decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N;
2311
+ decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE_M * WEBGPU_MUL_MAT_WG_SIZE_N;
2312
+
2313
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
2314
+ pipeline.context = decisions;
2315
+ mul_mat_id_pipelines[key] = pipeline;
2316
+ return mul_mat_id_pipelines[key];
2317
+ }
2318
+
2319
+ webgpu_pipeline get_mul_mat_id_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
2320
+ ggml_webgpu_mul_mat_id_pipeline_key key = {};
2321
+ key.src0_type = context.src0->type;
2322
+ key.src1_type = context.src1->type;
2323
+ key.n_experts = context.src0->ne[2];
2324
+ key.vectorized = (context.src0->ne[0] % 4 == 0 &&
2325
+ (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
2326
+ 1 :
2327
+ 0;
2328
+
2329
+ auto it = mul_mat_id_vec_pipelines.find(key);
2330
+ if (it != mul_mat_id_vec_pipelines.end()) {
2331
+ return it->second;
2332
+ }
2333
+
2334
+ std::vector<std::string> defines;
2335
+ std::string variant = "mul_mat_id_vec";
2336
+ const char * shader_src = wgsl_mul_mat_id_vec;
2337
+
2338
+ // src1 type
2339
+ switch (context.src1->type) {
2340
+ case GGML_TYPE_F32:
2341
+ defines.push_back("SRC1_INNER_TYPE=f32");
2342
+ break;
2343
+ case GGML_TYPE_F16:
2344
+ defines.push_back("SRC1_INNER_TYPE=f16");
2345
+ break;
2346
+ default:
2347
+ GGML_ABORT("Unsupported src1 type for mul_mat fast shader");
2348
+ }
2349
+
2350
+ // src0 type
2351
+ switch (context.src0->type) {
2352
+ case GGML_TYPE_F32:
2353
+ defines.push_back("SRC0_INNER_TYPE=f32");
2354
+ defines.push_back("MUL_ACC_FLOAT");
2355
+ variant += "_f32";
2356
+ break;
2357
+ case GGML_TYPE_F16:
2358
+ defines.push_back("SRC0_INNER_TYPE=f16");
2359
+ defines.push_back("MUL_ACC_FLOAT");
2360
+ variant += "_f16";
2361
+ break;
2362
+ default:
2363
+ {
2364
+ // Quantized types: use helpers but accumulate in f16
2365
+ const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
2366
+ std::string src0_name = src0_traits->type_name;
2367
+ std::string type_upper = src0_name;
2368
+ variant += "_" + src0_name;
2369
+ std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
2370
+
2371
+ defines.push_back("BYTE_HELPERS");
2372
+ defines.push_back("MUL_ACC_" + type_upper);
2373
+ defines.push_back("U32_DEQUANT_HELPERS");
2374
+ defines.push_back("SRC0_INNER_TYPE=u32");
2375
+ switch (context.src0->type) {
2376
+ case GGML_TYPE_IQ1_S:
2377
+ case GGML_TYPE_IQ1_M:
2378
+ case GGML_TYPE_IQ2_S:
2379
+ case GGML_TYPE_IQ3_S:
2380
+ case GGML_TYPE_IQ4_NL:
2381
+ case GGML_TYPE_IQ4_XS:
2382
+ defines.push_back(type_upper + "_GRID");
2383
+ break;
2384
+ case GGML_TYPE_IQ2_XXS:
2385
+ case GGML_TYPE_IQ2_XS:
2386
+ case GGML_TYPE_IQ3_XXS:
2387
+ defines.push_back(type_upper + "_GRID");
2388
+ defines.push_back(type_upper + "_TABLES");
2389
+ break;
2390
+ case GGML_TYPE_MXFP4:
2391
+ defines.push_back(type_upper + "_LUT");
2392
+ break;
2393
+ default:
2394
+ break;
2395
+ }
2396
+ break;
2397
+ }
2398
+ }
2399
+
2400
+ // VEC/SCALAR controls
2401
+ defines.push_back(key.vectorized ? "VEC" : "SCALAR");
2402
+
2403
+ uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
2404
+ uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
2405
+
2406
+ if (key.src0_type == GGML_TYPE_Q1_0) {
2407
+ outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
2408
+ } else if (key.src0_type >= GGML_TYPE_Q2_K) {
2409
+ outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
2410
+ } else if (key.src0_type >= GGML_TYPE_Q4_0) {
2411
+ outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
2412
+ }
2413
+
2414
+ // variant suffix for src1 type
2415
+ variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16");
2416
+
2417
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
2418
+ defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
2419
+ defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION");
2420
+ variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce";
2421
+ if (key.vectorized) {
2422
+ variant += "_vectorized";
2423
+ }
2424
+
2425
+ defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts));
2426
+
2427
+ auto processed = preprocessor.preprocess(shader_src, defines);
2428
+
2429
+ auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
2430
+ decisions->wg_size = wg_size;
2431
+ decisions->outputs_per_wg = outputs_per_wg;
1046
2432
 
1047
2433
  webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1048
2434
  pipeline.context = decisions;
1049
- mul_mat_legacy_pipelines[key] = pipeline;
1050
- return mul_mat_legacy_pipelines[key];
2435
+ mul_mat_id_vec_pipelines[key] = pipeline;
2436
+ return mul_mat_id_vec_pipelines[key];
1051
2437
  }
1052
2438
 
1053
2439
  webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) {
1054
2440
  const bool is_unary = context.dst->op == GGML_OP_UNARY;
1055
2441
  const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op;
1056
- ggml_webgpu_unary_pipeline_key key = {
1057
- .type = context.dst->type,
1058
- .op = op,
1059
- .is_unary = is_unary,
1060
- .inplace = context.inplace,
1061
- };
2442
+ ggml_webgpu_unary_pipeline_key key = {};
2443
+ key.type = context.dst->type;
2444
+ key.op = op;
2445
+ key.is_unary = is_unary;
2446
+ key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst) || context.dst->op == GGML_OP_FILL;
2447
+ key.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0);
1062
2448
 
1063
2449
  auto it = unary_pipelines.find(key);
1064
2450
  if (it != unary_pipelines.end()) {
@@ -1088,25 +2474,88 @@ class ggml_webgpu_shader_lib {
1088
2474
  variant += "_inplace";
1089
2475
  }
1090
2476
 
2477
+ if (op == GGML_OP_TRI) {
2478
+ switch (key.ttype) {
2479
+ case GGML_TRI_TYPE_LOWER:
2480
+ defines.push_back("TRI_TYPE_LOWER");
2481
+ variant += "_tri_type_lower";
2482
+ break;
2483
+ case GGML_TRI_TYPE_LOWER_DIAG:
2484
+ defines.push_back("TRI_TYPE_LOWER_DIAG");
2485
+ variant += "_tri_type_lower_diag";
2486
+ break;
2487
+ case GGML_TRI_TYPE_UPPER:
2488
+ defines.push_back("TRI_TYPE_UPPER");
2489
+ variant += "_tri_type_upper";
2490
+ break;
2491
+ case GGML_TRI_TYPE_UPPER_DIAG:
2492
+ defines.push_back("TRI_TYPE_UPPER_DIAG");
2493
+ variant += "_tri_upper_diag";
2494
+ break;
2495
+ default:
2496
+ GGML_ABORT("Unsupported ggml_tri_type for unary shader");
2497
+ }
2498
+ }
2499
+
1091
2500
  defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
1092
2501
 
1093
2502
  auto processed = preprocessor.preprocess(wgsl_unary, defines);
1094
2503
  auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
1095
2504
  decisions->wg_size = context.max_wg_size;
2505
+ decisions->inplace = key.inplace;
1096
2506
  webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1097
2507
  pipeline.context = decisions;
1098
2508
  unary_pipelines[key] = pipeline;
1099
2509
  return unary_pipelines[key];
1100
2510
  }
1101
2511
 
2512
+ webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) {
2513
+ ggml_webgpu_rms_norm_mul_pipeline_key key = {};
2514
+ key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
2515
+ key.overlap = ggml_webgpu_tensor_equal(context.src1, context.dst);
2516
+ key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);
2517
+
2518
+ auto it = rms_norm_mul_pipelines.find(key);
2519
+ if (it != rms_norm_mul_pipelines.end()) {
2520
+ return it->second;
2521
+ }
2522
+
2523
+ std::vector<std::string> defines;
2524
+ std::string op_name = "RMS_NORM_MUL";
2525
+ std::string variant = op_name;
2526
+
2527
+ if (key.inplace) {
2528
+ defines.push_back("INPLACE");
2529
+ variant += "_inplace";
2530
+ } else if (key.overlap) {
2531
+ defines.push_back("OVERLAP");
2532
+ variant += "_overlap";
2533
+ } else if (key.src_overlap) {
2534
+ defines.push_back("SRC_OVERLAP");
2535
+ variant += "_src_overlap";
2536
+ }
2537
+
2538
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
2539
+
2540
+ auto processed = preprocessor.preprocess(wgsl_rms_norm_mul, defines);
2541
+ auto pipeline_decisions = std::make_shared<ggml_webgpu_rms_norm_mul_shader_decisions>();
2542
+ pipeline_decisions->wg_size = context.max_wg_size;
2543
+ pipeline_decisions->inplace = key.inplace;
2544
+ pipeline_decisions->overlap = key.overlap;
2545
+ pipeline_decisions->src_overlap = key.src_overlap;
2546
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
2547
+ pipeline.context = pipeline_decisions;
2548
+ rms_norm_mul_pipelines[key] = pipeline;
2549
+ return rms_norm_mul_pipelines[key];
2550
+ }
2551
+
1102
2552
  webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
1103
- ggml_webgpu_binary_pipeline_key key = {
1104
- .type = context.dst->type,
1105
- .op = context.dst->op,
1106
- .inplace = context.inplace,
1107
- .overlap = context.overlap,
1108
- .src_overlap = context.src_overlap,
1109
- };
2553
+ ggml_webgpu_binary_pipeline_key key = {};
2554
+ key.type = context.dst->type;
2555
+ key.op = context.dst->op;
2556
+ key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
2557
+ key.overlap = ggml_webgpu_tensor_equal(context.src1, context.dst);
2558
+ key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);
1110
2559
 
1111
2560
  auto it = binary_pipelines.find(key);
1112
2561
  if (it != binary_pipelines.end()) {
@@ -1145,19 +2594,54 @@ class ggml_webgpu_shader_lib {
1145
2594
 
1146
2595
  defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
1147
2596
 
1148
- auto processed = preprocessor.preprocess(wgsl_binary, defines);
1149
- auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
1150
- decisions->wg_size = context.max_wg_size;
2597
+ auto processed = preprocessor.preprocess(wgsl_binary, defines);
2598
+ auto pipeline_decisions = std::make_shared<ggml_webgpu_binary_shader_decisions>();
2599
+ pipeline_decisions->wg_size = context.max_wg_size;
2600
+ pipeline_decisions->inplace = key.inplace;
2601
+ pipeline_decisions->overlap = key.overlap;
2602
+ pipeline_decisions->src_overlap = key.src_overlap;
2603
+
1151
2604
  webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1152
- pipeline.context = decisions;
2605
+ pipeline.context = pipeline_decisions;
1153
2606
  binary_pipelines[key] = pipeline;
1154
2607
  return binary_pipelines[key];
1155
2608
  }
1156
2609
 
2610
+ webgpu_pipeline get_add_id_pipeline(const ggml_webgpu_shader_lib_context & context) {
2611
+ ggml_webgpu_add_id_pipeline_key key = {};
2612
+ key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
2613
+
2614
+ auto it = add_id_pipelines.find(key);
2615
+ if (it != add_id_pipelines.end()) {
2616
+ return it->second;
2617
+ }
2618
+
2619
+ std::vector<std::string> defines;
2620
+ std::string variant = "add_id";
2621
+ const char * shader_src = wgsl_add_id;
2622
+
2623
+ if (key.inplace) {
2624
+ defines.push_back("INPLACE");
2625
+ variant += "_inplace";
2626
+ }
2627
+
2628
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
2629
+
2630
+ auto processed = preprocessor.preprocess(shader_src, defines);
2631
+ auto pipeline_decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
2632
+ pipeline_decisions->wg_size = context.max_wg_size;
2633
+ pipeline_decisions->inplace = key.inplace;
2634
+
2635
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
2636
+ pipeline.context = pipeline_decisions;
2637
+ add_id_pipelines[key] = pipeline;
2638
+ return pipeline;
2639
+ }
2640
+
1157
2641
  webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) {
1158
- ggml_webgpu_concat_pipeline_key key = {
1159
- .type = context.dst->type,
1160
- };
2642
+ ggml_webgpu_concat_pipeline_key key = {};
2643
+ key.type = context.dst->type;
2644
+ key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);
1161
2645
 
1162
2646
  auto it = concat_pipelines.find(key);
1163
2647
  if (it != concat_pipelines.end()) {
@@ -1180,11 +2664,17 @@ class ggml_webgpu_shader_lib {
1180
2664
  GGML_ABORT("Unsupported type for concat shader");
1181
2665
  }
1182
2666
 
2667
+ if (key.src_overlap) {
2668
+ defines.push_back("SRC_OVERLAP");
2669
+ variant += "_src_overlap";
2670
+ }
2671
+
1183
2672
  defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
1184
2673
 
1185
2674
  auto processed = preprocessor.preprocess(wgsl_concat, defines);
1186
- auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
2675
+ auto decisions = std::make_shared<ggml_webgpu_binary_shader_decisions>();
1187
2676
  decisions->wg_size = context.max_wg_size;
2677
+ decisions->src_overlap = key.src_overlap;
1188
2678
  webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1189
2679
  pipeline.context = decisions;
1190
2680
  concat_pipelines[key] = pipeline;
@@ -1192,9 +2682,8 @@ class ggml_webgpu_shader_lib {
1192
2682
  }
1193
2683
 
1194
2684
  webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) {
1195
- ggml_webgpu_repeat_pipeline_key key = {
1196
- .type = context.dst->type,
1197
- };
2685
+ ggml_webgpu_repeat_pipeline_key key = {};
2686
+ key.type = context.dst->type;
1198
2687
 
1199
2688
  auto it = repeat_pipelines.find(key);
1200
2689
  if (it != repeat_pipelines.end()) {
@@ -1233,102 +2722,551 @@ class ggml_webgpu_shader_lib {
1233
2722
  }
1234
2723
 
1235
2724
  webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
1236
- const bool has_mask = context.src3 != nullptr;
1237
- const bool has_sinks = context.src4 != nullptr;
1238
-
1239
- bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) &&
1240
- (context.src1->ne[1] % context.sg_mat_n == 0);
1241
-
1242
- ggml_webgpu_flash_attn_pipeline_key key = {
1243
- .kv_type = context.src1->type,
1244
- .head_dim_qk = (uint32_t) context.src0->ne[0],
1245
- .head_dim_v = (uint32_t) context.src2->ne[0],
1246
- .kv_direct = kv_direct,
1247
- .has_mask = has_mask,
1248
- .has_sinks = has_sinks,
1249
- .uses_logit_softcap = (*(float *) &context.dst->op_params[2]) != 0.0f,
1250
- };
2725
+ const bool can_use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path(
2726
+ context.supports_subgroup_matrix, context.sg_mat_k, context.sg_mat_n, context.src0, context.src2);
2727
+ ggml_webgpu_flash_attn_decisions decisions = {};
2728
+ decisions.use_sg_matrix = can_use_subgroup_matrix;
2729
+ decisions.q_tile = decisions.use_sg_matrix ? context.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE;
2730
+
2731
+ ggml_webgpu_flash_attn_pipeline_key key = {};
2732
+ key.common =
2733
+ ggml_webgpu_flash_attn_make_common_pipeline_key(context, decisions.use_sg_matrix ? context.sg_mat_k : 1u);
2734
+ key.common.kv_direct = decisions.use_sg_matrix && key.common.kv_direct;
2735
+ key.use_sg_matrix = decisions.use_sg_matrix;
2736
+
2737
+ const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(
2738
+ context.wg_mem_limit_bytes, decisions.q_tile, decisions.use_sg_matrix ? context.sg_mat_n : 1u,
2739
+ key.common.head_dim_qk, key.common.head_dim_v, key.common.has_mask, key.common.kv_direct);
2740
+ GGML_ASSERT(max_kv_tile > 0);
2741
+
2742
+ decisions.kv_tile = decisions.use_sg_matrix ?
2743
+ std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES) :
2744
+ std::min(GGML_WEBGPU_FLASH_ATTN_TILE_MAX_KV_TILE, max_kv_tile);
2745
+ decisions.wg_size =
2746
+ decisions.use_sg_matrix ?
2747
+ std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE) :
2748
+ std::min(context.max_wg_size, std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE,
2749
+ GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size));
2750
+
2751
+ if (key.common.kv_direct) {
2752
+ decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
2753
+ while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) {
2754
+ decisions.kv_tile -= decisions.use_sg_matrix ? context.sg_mat_n : context.min_subgroup_size;
2755
+ }
2756
+ }
1251
2757
 
1252
2758
  auto it = flash_attn_pipelines.find(key);
1253
2759
  if (it != flash_attn_pipelines.end()) {
1254
2760
  return it->second;
1255
2761
  }
1256
2762
 
2763
+ std::string variant = decisions.use_sg_matrix ? "flash_attn" : "flash_attn_tile";
2764
+ std::vector<std::string> defines = ggml_webgpu_flash_attn_common_defines(key.common, variant, decisions.q_tile,
2765
+ decisions.kv_tile, decisions.wg_size);
2766
+ const char * shader_src = nullptr;
2767
+ if (!key.use_sg_matrix) {
2768
+ shader_src = wgsl_flash_attn_tile;
2769
+ defines.push_back("MIN_SUBGROUP_SIZE=" + std::to_string(context.min_subgroup_size) + "u");
2770
+ defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u");
2771
+ variant += "_tile_sg" + std::to_string(context.min_subgroup_size) + "_" +
2772
+ std::to_string(context.max_subgroup_size);
2773
+ } else {
2774
+ shader_src = wgsl_flash_attn;
2775
+ defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
2776
+ defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
2777
+ defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
2778
+ }
2779
+ auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions);
2780
+ webgpu_pipeline pipeline =
2781
+ ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant);
2782
+ pipeline.context = pipeline_decisions;
2783
+ flash_attn_pipelines[key] = pipeline;
2784
+ return flash_attn_pipelines[key];
2785
+ }
2786
+
2787
+ webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
2788
+ ggml_webgpu_flash_attn_vec_pipeline_key key = {};
2789
+ key.common = ggml_webgpu_flash_attn_make_common_pipeline_key(context, GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH);
2790
+
2791
+ auto it = flash_attn_vec_pipelines.find(key);
2792
+ if (it != flash_attn_vec_pipelines.end()) {
2793
+ return it->second;
2794
+ }
2795
+
2796
+ ggml_webgpu_flash_attn_vec_decisions decisions = {};
2797
+ decisions.kv_tile =
2798
+ ggml_webgpu_flash_attn_get_vec_kv_tile(context.wg_mem_limit_bytes, key.common.head_dim_qk,
2799
+ key.common.head_dim_v, key.common.has_mask, key.common.kv_direct);
2800
+ decisions.wg_size = context.max_subgroup_size;
2801
+
2802
+ std::string variant = "flash_attn_vec";
2803
+ std::vector<std::string> defines =
2804
+ ggml_webgpu_flash_attn_common_defines(key.common, variant, 1u, decisions.kv_tile, decisions.wg_size);
2805
+ if (key.common.has_mask) {
2806
+ defines.push_back("BLK");
2807
+ variant.resize(variant.size() - (sizeof("_mask") - 1));
2808
+ variant += "_mask_blk";
2809
+ }
2810
+ uint32_t vec_ne = 1u;
2811
+ if (key.common.k_type == GGML_TYPE_F16 && key.common.v_type == GGML_TYPE_F16 &&
2812
+ key.common.head_dim_qk == key.common.head_dim_v) {
2813
+ switch (key.common.head_dim_qk) {
2814
+ case 64:
2815
+ case 192:
2816
+ case 576:
2817
+ vec_ne = 2u;
2818
+ break;
2819
+ case 96:
2820
+ vec_ne = 4u;
2821
+ break;
2822
+ default:
2823
+ break;
2824
+ }
2825
+ }
2826
+ defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u");
2827
+
2828
+ auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_vec_decisions>(decisions);
2829
+ webgpu_pipeline pipeline =
2830
+ ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_split, defines), variant);
2831
+ pipeline.context = pipeline_decisions;
2832
+ flash_attn_vec_pipelines[key] = pipeline;
2833
+ return flash_attn_vec_pipelines[key];
2834
+ }
2835
+
2836
+ webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context, uint32_t kv_tile) {
2837
+ ggml_webgpu_flash_attn_blk_pipeline_key key = {};
2838
+ key.kv_tile = kv_tile;
2839
+ auto it = flash_attn_blk_pipelines.find(key);
2840
+ if (it != flash_attn_blk_pipelines.end()) {
2841
+ return it->second;
2842
+ }
2843
+
1257
2844
  std::vector<std::string> defines;
1258
- std::string variant = "flash_attn";
2845
+ std::string variant = "flash_attn_vec_blk";
1259
2846
 
1260
- switch (key.kv_type) {
2847
+ defines.push_back(std::string("KV_TILE=") + std::to_string(key.kv_tile));
2848
+ variant += std::string("_kvt") + std::to_string(key.kv_tile);
2849
+
2850
+ uint32_t wg_size = 1;
2851
+ while ((wg_size << 1) <= context.max_wg_size) {
2852
+ wg_size <<= 1;
2853
+ }
2854
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
2855
+ variant += std::string("_wg") + std::to_string(wg_size);
2856
+
2857
+ webgpu_pipeline pipeline =
2858
+ ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_blk, defines), variant);
2859
+ flash_attn_blk_pipelines[key] = pipeline;
2860
+ return flash_attn_blk_pipelines[key];
2861
+ }
2862
+
2863
+ webgpu_pipeline get_flash_attn_vec_reduce_pipeline(const ggml_webgpu_shader_lib_context & context) {
2864
+ ggml_webgpu_flash_attn_vec_reduce_pipeline_key key = {};
2865
+ key.head_dim_v = (uint32_t) context.src2->ne[0];
2866
+ key.dst_type = context.dst->type;
2867
+ key.wg_size = context.max_wg_size;
2868
+ auto it = flash_attn_vec_reduce_pipelines.find(key);
2869
+ if (it != flash_attn_vec_reduce_pipelines.end()) {
2870
+ return it->second;
2871
+ }
2872
+
2873
+ std::vector<std::string> defines;
2874
+ std::string variant = "flash_attn_vec_reduce";
2875
+
2876
+ switch (key.dst_type) {
1261
2877
  case GGML_TYPE_F32:
1262
- defines.push_back("KV_F32");
2878
+ defines.push_back("DST_F32");
1263
2879
  break;
1264
2880
  case GGML_TYPE_F16:
1265
- defines.push_back("KV_F16");
2881
+ defines.push_back("DST_F16");
1266
2882
  break;
1267
- case GGML_TYPE_Q4_0:
1268
- defines.push_back("KV_Q4_0");
2883
+ default:
2884
+ GGML_ABORT("Unsupported dst type for flash attention vec reduce shader");
2885
+ }
2886
+ variant += std::string("_dst") + ggml_type_name(key.dst_type);
2887
+
2888
+ defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
2889
+ variant += std::string("_hsv") + std::to_string(key.head_dim_v);
2890
+
2891
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
2892
+ variant += std::string("_wg") + std::to_string(context.max_wg_size);
2893
+
2894
+ webgpu_pipeline pipeline =
2895
+ ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_reduce, defines), variant);
2896
+ flash_attn_vec_reduce_pipelines[key] = pipeline;
2897
+ return flash_attn_vec_reduce_pipelines[key];
2898
+ }
2899
+
2900
+ webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) {
2901
+ ggml_webgpu_cpy_pipeline_key key = {};
2902
+ key.src_type = context.src0->type;
2903
+ key.dst_type = context.dst->type;
2904
+
2905
+ auto it = cpy_pipelines.find(key);
2906
+ if (it != cpy_pipelines.end()) {
2907
+ return it->second;
2908
+ }
2909
+
2910
+ std::vector<std::string> defines;
2911
+ std::string variant = "cpy";
2912
+
2913
+ switch (key.src_type) {
2914
+ case GGML_TYPE_F32:
2915
+ defines.push_back("SRC_F32");
2916
+ variant += "_f32";
1269
2917
  break;
1270
- case GGML_TYPE_Q8_0:
1271
- defines.push_back("KV_Q8_0");
2918
+ case GGML_TYPE_F16:
2919
+ defines.push_back("SRC_F16");
2920
+ variant += "_f16";
2921
+ break;
2922
+ default:
2923
+ GGML_ABORT("Unsupported src type for cpy shader");
2924
+ }
2925
+
2926
+ switch (key.dst_type) {
2927
+ case GGML_TYPE_F32:
2928
+ defines.push_back("DST_F32");
2929
+ variant += "_f32";
2930
+ break;
2931
+ case GGML_TYPE_F16:
2932
+ defines.push_back("DST_F16");
2933
+ variant += "_f16";
2934
+ break;
2935
+ case GGML_TYPE_I32:
2936
+ defines.push_back("DST_I32");
2937
+ variant += "_i32";
2938
+ break;
2939
+ default:
2940
+ GGML_ABORT("Unsupported dst type for cpy shader");
2941
+ }
2942
+
2943
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
2944
+
2945
+ auto processed = preprocessor.preprocess(wgsl_cpy, defines);
2946
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
2947
+ decisions->wg_size = context.max_wg_size;
2948
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
2949
+ pipeline.context = decisions;
2950
+ cpy_pipelines[key] = pipeline;
2951
+ return cpy_pipelines[key];
2952
+ }
2953
+
2954
+ webgpu_pipeline get_glu_pipeline(const ggml_webgpu_shader_lib_context & context) {
2955
+ ggml_webgpu_glu_pipeline_key key = {};
2956
+ key.glu_op = ggml_get_glu_op(context.dst);
2957
+ key.type = context.dst->type;
2958
+ key.split = (context.src1 != nullptr);
2959
+
2960
+ auto it = glu_pipelines.find(key);
2961
+ if (it != glu_pipelines.end()) {
2962
+ return it->second;
2963
+ }
2964
+
2965
+ std::vector<std::string> defines;
2966
+ std::string variant = "glu";
2967
+
2968
+ switch (key.glu_op) {
2969
+ case GGML_GLU_OP_REGLU:
2970
+ defines.push_back("OP_REGLU");
2971
+ variant += "_reglu";
2972
+ break;
2973
+ case GGML_GLU_OP_GEGLU:
2974
+ defines.push_back("OP_GEGLU");
2975
+ variant += "_geglu";
2976
+ break;
2977
+ case GGML_GLU_OP_SWIGLU:
2978
+ defines.push_back("OP_SWIGLU");
2979
+ variant += "_swiglu";
2980
+ break;
2981
+ case GGML_GLU_OP_SWIGLU_OAI:
2982
+ defines.push_back("OP_SWIGLU_OAI");
2983
+ variant += "_swiglu_oai";
2984
+ break;
2985
+ case GGML_GLU_OP_GEGLU_ERF:
2986
+ defines.push_back("OP_GEGLU_ERF");
2987
+ variant += "_geglu_erf";
2988
+ break;
2989
+ case GGML_GLU_OP_GEGLU_QUICK:
2990
+ defines.push_back("OP_GEGLU_QUICK");
2991
+ variant += "_geglu_quick";
2992
+ break;
2993
+ default:
2994
+ GGML_ABORT("Unsupported GLU op");
2995
+ }
2996
+ switch (key.type) {
2997
+ case GGML_TYPE_F32:
2998
+ defines.push_back("TYPE_F32");
2999
+ variant += "_f32";
3000
+ break;
3001
+ case GGML_TYPE_F16:
3002
+ defines.push_back("TYPE_F16");
3003
+ variant += "_f16";
3004
+ break;
3005
+ default:
3006
+ GGML_ABORT("Unsupported type for GLU shader");
3007
+ }
3008
+
3009
+ if (key.split) {
3010
+ variant += "_split";
3011
+ } else {
3012
+ defines.push_back("NO_SPLIT");
3013
+ }
3014
+
3015
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
3016
+
3017
+ auto processed = preprocessor.preprocess(wgsl_glu, defines);
3018
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
3019
+ decisions->wg_size = context.max_wg_size;
3020
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
3021
+ pipeline.context = decisions;
3022
+ glu_pipelines[key] = pipeline;
3023
+ return glu_pipelines[key];
3024
+ }
3025
+
3026
+ webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) {
3027
+ ggml_webgpu_rope_pipeline_key key = {};
3028
+ key.type = context.dst->type;
3029
+ key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
3030
+ key.has_ff = (context.src2 != nullptr);
3031
+
3032
+ auto it = rope_pipelines.find(key);
3033
+ if (it != rope_pipelines.end()) {
3034
+ return it->second;
3035
+ }
3036
+
3037
+ std::vector<std::string> defines;
3038
+ std::string variant = "rope";
3039
+
3040
+ switch (key.type) {
3041
+ case GGML_TYPE_F32:
3042
+ defines.push_back("TYPE_F32");
3043
+ variant += "_f32";
3044
+ break;
3045
+ case GGML_TYPE_F16:
3046
+ defines.push_back("TYPE_F16");
3047
+ variant += "_f16";
1272
3048
  break;
1273
3049
  default:
1274
- GGML_ABORT("Unsupported KV type for flash attention shader");
3050
+ GGML_ABORT("Unsupported type for ROPE shader");
3051
+ }
3052
+
3053
+ if (key.inplace) {
3054
+ defines.push_back("INPLACE");
3055
+ variant += "_inplace";
3056
+ }
3057
+
3058
+ if (key.has_ff) {
3059
+ defines.push_back("FF_FUNC");
3060
+ variant += "_ff";
1275
3061
  }
1276
- variant += std::string("_") + ggml_type_name(key.kv_type);
3062
+
3063
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
3064
+
3065
+ auto processed = preprocessor.preprocess(wgsl_rope, defines);
3066
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
3067
+ decisions->wg_size = context.max_wg_size;
3068
+ decisions->inplace = key.inplace;
3069
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
3070
+ pipeline.context = decisions;
3071
+ rope_pipelines[key] = pipeline;
3072
+ return rope_pipelines[key];
3073
+ }
3074
+
3075
+ webgpu_pipeline get_soft_max_pipeline(const ggml_webgpu_shader_lib_context & context) {
3076
+ ggml_webgpu_soft_max_pipeline_key key = {};
3077
+ key.mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32;
3078
+ key.has_mask = (context.src1 != nullptr);
3079
+ key.has_sink = (context.src2 != nullptr);
3080
+ key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
3081
+
3082
+ auto it = soft_max_pipelines.find(key);
3083
+ if (it != soft_max_pipelines.end()) {
3084
+ return it->second;
3085
+ }
3086
+
3087
+ std::vector<std::string> defines;
3088
+ std::string variant = "soft_max";
1277
3089
 
1278
3090
  if (key.has_mask) {
1279
- defines.push_back("MASK");
1280
- variant += "_mask";
3091
+ defines.push_back("HAS_MASK");
3092
+ switch (key.mask_type) {
3093
+ case GGML_TYPE_F32:
3094
+ defines.push_back("MASK_F32");
3095
+ variant += "_mask_f32";
3096
+ break;
3097
+ case GGML_TYPE_F16:
3098
+ defines.push_back("MASK_F16");
3099
+ variant += "_mask_f16";
3100
+ break;
3101
+ default:
3102
+ GGML_ABORT("Unsupported type for SOFT_MAX shader");
3103
+ }
1281
3104
  }
1282
- if (key.has_sinks) {
1283
- defines.push_back("SINKS");
1284
- variant += "_sinks";
3105
+
3106
+ if (key.has_sink) {
3107
+ defines.push_back("HAS_SINK");
3108
+ variant += "_sink";
1285
3109
  }
1286
- if (key.uses_logit_softcap) {
1287
- defines.push_back("LOGIT_SOFTCAP");
1288
- variant += "_lgsc";
3110
+
3111
+ if (key.inplace) {
3112
+ defines.push_back("INPLACE");
3113
+ variant += "_inplace";
1289
3114
  }
1290
- if (key.kv_direct) {
1291
- defines.push_back("KV_DIRECT");
1292
- variant += "_kvdirect";
3115
+
3116
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
3117
+
3118
+ auto processed = preprocessor.preprocess(wgsl_soft_max, defines);
3119
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
3120
+ decisions->wg_size = context.max_wg_size;
3121
+ decisions->inplace = key.inplace;
3122
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
3123
+ pipeline.context = decisions;
3124
+ soft_max_pipelines[key] = pipeline;
3125
+ return soft_max_pipelines[key];
3126
+ }
3127
+
3128
+ webgpu_pipeline get_conv2d_pipeline(const ggml_webgpu_shader_lib_context & context) {
3129
+ ggml_webgpu_conv2d_pipeline_key key = {};
3130
+ key.weight_type = context.src0->type;
3131
+ key.input_type = context.src1->type;
3132
+ key.output_type = context.dst->type;
3133
+
3134
+ auto it = conv2d_pipelines.find(key);
3135
+ if (it != conv2d_pipelines.end()) {
3136
+ return it->second;
1293
3137
  }
1294
3138
 
1295
- defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
1296
- variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
3139
+ std::vector<std::string> defines;
3140
+ std::string variant = "conv_2d";
3141
+
3142
+ auto push_type_defines = [&](const char * prefix, ggml_type type) {
3143
+ std::string s_prefix = prefix;
3144
+ if (type == GGML_TYPE_F32) {
3145
+ defines.push_back(s_prefix + "_F32");
3146
+ } else if (type == GGML_TYPE_F16) {
3147
+ defines.push_back(s_prefix + "_F16");
3148
+ } else {
3149
+ GGML_ABORT("Unsupported type for CONV_2D shader");
3150
+ }
3151
+ };
1297
3152
 
1298
- defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
1299
- variant += std::string("_hsv") + std::to_string(key.head_dim_v);
3153
+ push_type_defines("WEIGHT", key.weight_type);
3154
+ push_type_defines("INPUT", key.input_type);
3155
+ push_type_defines("OUTPUT", key.output_type);
3156
+
3157
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
1300
3158
 
1301
- defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
1302
- defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
1303
- defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
1304
-
1305
- uint32_t q_tile = context.sg_mat_m;
1306
- uint32_t kv_tile =
1307
- std::min(ggml_webgpu_flash_attn_max_kv_tile({ key, context.sg_mat_m, context.sg_mat_n, context.sg_mat_k,
1308
- context.wg_mem_limit_bytes, context.max_subgroup_size }),
1309
- context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
1310
- if (key.kv_direct) {
1311
- while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
1312
- kv_tile -= context.sg_mat_n;
3159
+ auto processed = preprocessor.preprocess(wgsl_conv2d, defines);
3160
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
3161
+ decisions->wg_size = context.max_wg_size;
3162
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
3163
+ pipeline.context = decisions;
3164
+ conv2d_pipelines[key] = pipeline;
3165
+ return conv2d_pipelines[key];
3166
+ }
3167
+
3168
+ webgpu_pipeline get_im2col_pipeline(const ggml_webgpu_shader_lib_context & context) {
3169
+ ggml_webgpu_im2col_pipeline_key key = {};
3170
+ key.input_type = context.src1->type;
3171
+ key.output_type = context.dst->type;
3172
+
3173
+ auto it = im2col_pipelines.find(key);
3174
+ if (it != im2col_pipelines.end()) {
3175
+ return it->second;
3176
+ }
3177
+
3178
+ std::vector<std::string> defines;
3179
+ std::string variant = "im2col";
3180
+
3181
+ auto push_type_defines = [&](const char * prefix, ggml_type type) {
3182
+ std::string s_prefix = prefix;
3183
+ if (type == GGML_TYPE_F32) {
3184
+ defines.push_back(s_prefix + "_F32");
3185
+ } else if (type == GGML_TYPE_F16) {
3186
+ defines.push_back(s_prefix + "_F16");
3187
+ } else {
3188
+ GGML_ABORT("Unsupported type for IM2COL shader");
1313
3189
  }
3190
+ };
3191
+
3192
+ push_type_defines("INPUT", key.input_type);
3193
+ push_type_defines("OUTPUT", key.output_type);
3194
+
3195
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
3196
+
3197
+ auto processed = preprocessor.preprocess(wgsl_im2col, defines);
3198
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
3199
+ decisions->wg_size = context.max_wg_size;
3200
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
3201
+ pipeline.context = decisions;
3202
+ im2col_pipelines[key] = pipeline;
3203
+ return im2col_pipelines[key];
3204
+ }
3205
+
3206
+ webgpu_pipeline get_upscale_pipeline(const ggml_webgpu_shader_lib_context & context) {
3207
+ const uint32_t mode_flags = (uint32_t) ggml_get_op_params_i32(context.dst, 0);
3208
+ const uint32_t base_mode = mode_flags & 0xFFu;
3209
+ const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS) != 0u;
3210
+
3211
+ ggml_webgpu_upscale_pipeline_key key = {};
3212
+ key.input_type = context.src0->type;
3213
+ key.output_type = context.dst->type;
3214
+ key.base_mode = base_mode;
3215
+ key.antialias = antialias;
3216
+
3217
+ auto it = upscale_pipelines.find(key);
3218
+ if (it != upscale_pipelines.end()) {
3219
+ return it->second;
1314
3220
  }
1315
3221
 
1316
- defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
1317
- defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
3222
+ std::vector<std::string> defines;
3223
+ std::string variant = "upscale";
1318
3224
 
1319
- uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
1320
- defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
3225
+ if (key.input_type == GGML_TYPE_F16) {
3226
+ defines.push_back("SRC_F16");
3227
+ variant += "_src_f16";
3228
+ } else {
3229
+ variant += "_src_f32";
3230
+ }
1321
3231
 
1322
- auto processed = preprocessor.preprocess(wgsl_flash_attn, defines);
1323
- auto decisions = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
1324
- decisions->q_tile = q_tile;
1325
- decisions->kv_tile = kv_tile;
1326
- decisions->wg_size = wg_size;
3232
+ if (key.output_type == GGML_TYPE_F16) {
3233
+ defines.push_back("DST_F16");
3234
+ variant += "_dst_f16";
3235
+ } else {
3236
+ variant += "_dst_f32";
3237
+ }
1327
3238
 
1328
- webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1329
- pipeline.context = decisions;
1330
- flash_attn_pipelines[key] = pipeline;
1331
- return flash_attn_pipelines[key];
3239
+ switch (base_mode) {
3240
+ case GGML_SCALE_MODE_NEAREST:
3241
+ defines.push_back("NEAREST");
3242
+ variant += "_nearest";
3243
+ break;
3244
+ case GGML_SCALE_MODE_BILINEAR:
3245
+ defines.push_back("BILINEAR");
3246
+ variant += "_bilinear";
3247
+ break;
3248
+ case GGML_SCALE_MODE_BICUBIC:
3249
+ defines.push_back("BICUBIC");
3250
+ variant += "_bicubic";
3251
+ break;
3252
+ default:
3253
+ GGML_ABORT("Unsupported upscale mode");
3254
+ }
3255
+
3256
+ if (antialias) {
3257
+ defines.push_back("ANTIALIAS");
3258
+ variant += "_aa";
3259
+ }
3260
+
3261
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
3262
+
3263
+ auto processed = preprocessor.preprocess(wgsl_upscale, defines);
3264
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
3265
+ decisions->wg_size = context.max_wg_size;
3266
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
3267
+ pipeline.context = decisions;
3268
+ upscale_pipelines[key] = pipeline;
3269
+ return upscale_pipelines[key];
1332
3270
  }
1333
3271
 
1334
3272
  private:
@@ -1350,25 +3288,6 @@ class ggml_webgpu_shader_lib {
1350
3288
  pipeline_desc.layout = nullptr; // nullptr means auto layout
1351
3289
  return { device.CreateComputePipeline(&pipeline_desc), label };
1352
3290
  }
1353
-
1354
// NOTE(review): this helper appears only as removed ('-') lines in this diff. Its caller (the old
// flash_attn pipeline getter, which passed its result to std::min when choosing kv_tile) is removed too.
// Purpose (old version): the largest KV tile, rounded down to a multiple of context.sg_mat_n, whose byte
// estimate fits in context.wg_mem_limit_bytes. Fixed cost: (head_dim_qk + head_dim_v) * q_tile f16
// values plus 2 * q_tile f32 values. Per-KV-row cost in f16 units: max(head_dim_qk, head_dim_v) unless
// kv_direct, plus q_tile if has_mask, plus q_tile.
- static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
1355
- const size_t limit_bytes = context.wg_mem_limit_bytes;
1356
- const size_t q_tile = context.sg_mat_m;
1357
- const size_t base_q_bytes =
1358
- (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
1359
- 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
1360
- size_t bytes_per_kv = 0;
1361
- if (!context.key.kv_direct) {
1362
- bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v);
1363
- }
1364
- if (context.key.has_mask) {
1365
- bytes_per_kv += q_tile;
1366
- }
1367
- bytes_per_kv += q_tile;
1368
- bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
1369
// NOTE(review): limit_bytes and base_q_bytes are both size_t, so the subtraction wraps if base_q_bytes
// exceeds limit_bytes; the quotient is then narrowed to uint32_t.
- const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
1370
- return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
1371
- }
1372
3291
  };
1373
3292
 
1374
3293
  #endif // GGML_WEBGPU_SHADER_LIB_HPP